chore: Implement code review suggestions and format code.

This commit is contained in:
Elisiário Couto
2025-09-03 21:11:19 +01:00
committed by Elisiário Couto
parent 47164e8546
commit de3da84dff
42 changed files with 1144 additions and 966 deletions

View File

@@ -1,32 +0,0 @@
FROM python:3.13-alpine AS builder
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
WORKDIR /app
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --frozen --no-install-project --no-editable
COPY . /app
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-editable --no-group dev
FROM python:3.13-alpine
LABEL org.opencontainers.image.source="https://github.com/elisiariocouto/leggen"
LABEL org.opencontainers.image.authors="Elisiário Couto <elisiario@couto.io>"
LABEL org.opencontainers.image.licenses="MIT"
LABEL org.opencontainers.image.title="leggend"
LABEL org.opencontainers.image.description="Leggen API service"
LABEL org.opencontainers.image.url="https://github.com/elisiariocouto/leggen"
WORKDIR /app
ENV PATH="/app/.venv/bin:$PATH"
COPY --from=builder --chown=app:app /app/.venv /app/.venv
EXPOSE 8000
ENTRYPOINT ["/app/.venv/bin/leggend"]

View File

@@ -3,7 +3,6 @@ services:
leggend: leggend:
build: build:
context: . context: .
dockerfile: Dockerfile.leggend
restart: "unless-stopped" restart: "unless-stopped"
ports: ports:
- "127.0.0.1:8000:8000" - "127.0.0.1:8000:8000"
@@ -18,20 +17,6 @@ services:
timeout: 10s timeout: 10s
retries: 3 retries: 3
# CLI for one-off operations (uses leggend API)
leggen:
image: elisiariocouto/leggen:latest
command: sync --wait
restart: "no"
volumes:
- "./leggen:/root/.config/leggen"
- "./db:/app"
environment:
- LEGGEND_API_URL=http://leggend:8000
depends_on:
leggend:
condition: service_healthy
nocodb: nocodb:
image: nocodb/nocodb:latest image: nocodb/nocodb:latest
restart: "unless-stopped" restart: "unless-stopped"

View File

@@ -8,19 +8,20 @@ from leggen.utils.text import error
class LeggendAPIClient: class LeggendAPIClient:
"""Client for communicating with the leggend FastAPI service""" """Client for communicating with the leggend FastAPI service"""
def __init__(self, base_url: Optional[str] = None): def __init__(self, base_url: Optional[str] = None):
self.base_url = base_url or os.environ.get("LEGGEND_API_URL", "http://localhost:8000") self.base_url = base_url or os.environ.get(
"LEGGEND_API_URL", "http://localhost:8000"
)
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({ self.session.headers.update(
"Content-Type": "application/json", {"Content-Type": "application/json", "Accept": "application/json"}
"Accept": "application/json" )
})
def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]: def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
"""Make HTTP request to the API""" """Make HTTP request to the API"""
url = urljoin(self.base_url, endpoint) url = urljoin(self.base_url, endpoint)
try: try:
response = self.session.request(method, url, **kwargs) response = self.session.request(method, url, **kwargs)
response.raise_for_status() response.raise_for_status()
@@ -53,15 +54,19 @@ class LeggendAPIClient:
# Bank endpoints # Bank endpoints
def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]: def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]:
"""Get bank institutions for a country""" """Get bank institutions for a country"""
response = self._make_request("GET", "/api/v1/banks/institutions", params={"country": country}) response = self._make_request(
"GET", "/api/v1/banks/institutions", params={"country": country}
)
return response.get("data", []) return response.get("data", [])
def connect_to_bank(self, institution_id: str, redirect_url: str = "http://localhost:8000/") -> Dict[str, Any]: def connect_to_bank(
self, institution_id: str, redirect_url: str = "http://localhost:8000/"
) -> Dict[str, Any]:
"""Connect to a bank""" """Connect to a bank"""
response = self._make_request( response = self._make_request(
"POST", "POST",
"/api/v1/banks/connect", "/api/v1/banks/connect",
json={"institution_id": institution_id, "redirect_url": redirect_url} json={"institution_id": institution_id, "redirect_url": redirect_url},
) )
return response.get("data", {}) return response.get("data", {})
@@ -91,31 +96,39 @@ class LeggendAPIClient:
response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances") response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances")
return response.get("data", []) return response.get("data", [])
def get_account_transactions(self, account_id: str, limit: int = 100, summary_only: bool = False) -> List[Dict[str, Any]]: def get_account_transactions(
self, account_id: str, limit: int = 100, summary_only: bool = False
) -> List[Dict[str, Any]]:
"""Get account transactions""" """Get account transactions"""
response = self._make_request( response = self._make_request(
"GET", "GET",
f"/api/v1/accounts/{account_id}/transactions", f"/api/v1/accounts/{account_id}/transactions",
params={"limit": limit, "summary_only": summary_only} params={"limit": limit, "summary_only": summary_only},
) )
return response.get("data", []) return response.get("data", [])
# Transaction endpoints # Transaction endpoints
def get_all_transactions(self, limit: int = 100, summary_only: bool = True, **filters) -> List[Dict[str, Any]]: def get_all_transactions(
self, limit: int = 100, summary_only: bool = True, **filters
) -> List[Dict[str, Any]]:
"""Get all transactions with optional filters""" """Get all transactions with optional filters"""
params = {"limit": limit, "summary_only": summary_only} params = {"limit": limit, "summary_only": summary_only}
params.update(filters) params.update(filters)
response = self._make_request("GET", "/api/v1/transactions", params=params) response = self._make_request("GET", "/api/v1/transactions", params=params)
return response.get("data", []) return response.get("data", [])
def get_transaction_stats(self, days: int = 30, account_id: Optional[str] = None) -> Dict[str, Any]: def get_transaction_stats(
self, days: int = 30, account_id: Optional[str] = None
) -> Dict[str, Any]:
"""Get transaction statistics""" """Get transaction statistics"""
params = {"days": days} params = {"days": days}
if account_id: if account_id:
params["account_id"] = account_id params["account_id"] = account_id
response = self._make_request("GET", "/api/v1/transactions/stats", params=params) response = self._make_request(
"GET", "/api/v1/transactions/stats", params=params
)
return response.get("data", {}) return response.get("data", {})
# Sync endpoints # Sync endpoints
@@ -124,21 +137,25 @@ class LeggendAPIClient:
response = self._make_request("GET", "/api/v1/sync/status") response = self._make_request("GET", "/api/v1/sync/status")
return response.get("data", {}) return response.get("data", {})
def trigger_sync(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]: def trigger_sync(
self, account_ids: Optional[List[str]] = None, force: bool = False
) -> Dict[str, Any]:
"""Trigger a sync""" """Trigger a sync"""
data = {"force": force} data = {"force": force}
if account_ids: if account_ids:
data["account_ids"] = account_ids data["account_ids"] = account_ids
response = self._make_request("POST", "/api/v1/sync", json=data) response = self._make_request("POST", "/api/v1/sync", json=data)
return response.get("data", {}) return response.get("data", {})
def sync_now(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]: def sync_now(
self, account_ids: Optional[List[str]] = None, force: bool = False
) -> Dict[str, Any]:
"""Run sync synchronously""" """Run sync synchronously"""
data = {"force": force} data = {"force": force}
if account_ids: if account_ids:
data["account_ids"] = account_ids data["account_ids"] = account_ids
response = self._make_request("POST", "/api/v1/sync/now", json=data) response = self._make_request("POST", "/api/v1/sync/now", json=data)
return response.get("data", {}) return response.get("data", {})
@@ -147,11 +164,17 @@ class LeggendAPIClient:
response = self._make_request("GET", "/api/v1/sync/scheduler") response = self._make_request("GET", "/api/v1/sync/scheduler")
return response.get("data", {}) return response.get("data", {})
def update_scheduler_config(self, enabled: bool = True, hour: int = 3, minute: int = 0, cron: Optional[str] = None) -> Dict[str, Any]: def update_scheduler_config(
self,
enabled: bool = True,
hour: int = 3,
minute: int = 0,
cron: Optional[str] = None,
) -> Dict[str, Any]:
"""Update scheduler configuration""" """Update scheduler configuration"""
data = {"enabled": enabled, "hour": hour, "minute": minute} data = {"enabled": enabled, "hour": hour, "minute": minute}
if cron: if cron:
data["cron"] = cron data["cron"] = cron
response = self._make_request("PUT", "/api/v1/sync/scheduler", json=data) response = self._make_request("PUT", "/api/v1/sync/scheduler", json=data)
return response.get("data", {}) return response.get("data", {})

View File

@@ -12,10 +12,12 @@ def balances(ctx: click.Context):
List balances of all connected accounts List balances of all connected accounts
""" """
api_client = LeggendAPIClient(ctx.obj.get("api_url")) api_client = LeggendAPIClient(ctx.obj.get("api_url"))
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
accounts = api_client.get_accounts() accounts = api_client.get_accounts()
@@ -24,11 +26,7 @@ def balances(ctx: click.Context):
for account in accounts: for account in accounts:
for balance in account.get("balances", []): for balance in account.get("balances", []):
amount = round(float(balance["amount"]), 2) amount = round(float(balance["amount"]), 2)
symbol = ( symbol = "" if balance["currency"] == "EUR" else f" {balance['currency']}"
""
if balance["currency"] == "EUR"
else f" {balance['currency']}"
)
amount_str = f"{amount}{symbol}" amount_str = f"{amount}{symbol}"
date = ( date = (
datefmt(balance.get("last_change_date")) datefmt(balance.get("last_change_date"))

View File

@@ -1,36 +0,0 @@
import os
import click
from leggen.main import cli
cmd_folder = os.path.abspath(os.path.dirname(__file__))
class BankGroup(click.Group):
def list_commands(self, ctx):
rv = []
for filename in os.listdir(cmd_folder):
if filename.endswith(".py") and not filename.startswith("__init__"):
if filename == "list_banks.py":
rv.append("list")
else:
rv.append(filename[:-3])
rv.sort()
return rv
def get_command(self, ctx, name):
try:
if name == "list":
name = "list_banks"
mod = __import__(f"leggen.commands.bank.{name}", None, None, [name])
except ImportError:
return
return getattr(mod, name)
@cli.group(cls=BankGroup)
@click.pass_context
def bank(ctx):
"""Manage banks connections"""
return

View File

@@ -13,30 +13,32 @@ def add(ctx):
Connect to a bank Connect to a bank
""" """
api_client = LeggendAPIClient(ctx.obj.get("api_url")) api_client = LeggendAPIClient(ctx.obj.get("api_url"))
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
try: try:
# Get supported countries # Get supported countries
countries = api_client.get_supported_countries() countries = api_client.get_supported_countries()
country_codes = [c["code"] for c in countries] country_codes = [c["code"] for c in countries]
country = click.prompt( country = click.prompt(
"Bank Country", "Bank Country",
type=click.Choice(country_codes, case_sensitive=True), type=click.Choice(country_codes, case_sensitive=True),
default="PT", default="PT",
) )
info(f"Getting bank list for country: {country}") info(f"Getting bank list for country: {country}")
banks = api_client.get_institutions(country) banks = api_client.get_institutions(country)
if not banks: if not banks:
warning(f"No banks available for country {country}") warning(f"No banks available for country {country}")
return return
filtered_banks = [ filtered_banks = [
{ {
"id": bank["id"], "id": bank["id"],
@@ -46,14 +48,14 @@ def add(ctx):
for bank in banks for bank in banks
] ]
print_table(filtered_banks) print_table(filtered_banks)
allowed_ids = [str(bank["id"]) for bank in banks] allowed_ids = [str(bank["id"]) for bank in banks]
bank_id = click.prompt("Bank ID", type=click.Choice(allowed_ids)) bank_id = click.prompt("Bank ID", type=click.Choice(allowed_ids))
# Show bank details # Show bank details
selected_bank = next(bank for bank in banks if bank["id"] == bank_id) selected_bank = next(bank for bank in banks if bank["id"] == bank_id)
info(f"Selected bank: {selected_bank['name']}") info(f"Selected bank: {selected_bank['name']}")
click.confirm("Do you agree to connect to this bank?", abort=True) click.confirm("Do you agree to connect to this bank?", abort=True)
info(f"Connecting to bank with ID: {bank_id}") info(f"Connecting to bank with ID: {bank_id}")
@@ -65,11 +67,15 @@ def add(ctx):
save_file(f"req_{result['id']}.json", result) save_file(f"req_{result['id']}.json", result)
success("Bank connection request created successfully!") success("Bank connection request created successfully!")
warning(f"Please open the following URL in your browser to complete the authorization:") warning(
f"Please open the following URL in your browser to complete the authorization:"
)
click.echo(f"\n{result['link']}\n") click.echo(f"\n{result['link']}\n")
info(f"Requisition ID: {result['id']}") info(f"Requisition ID: {result['id']}")
info("After completing the authorization, you can check the connection status with 'leggen status'") info(
"After completing the authorization, you can check the connection status with 'leggen status'"
)
except Exception as e: except Exception as e:
click.echo(f"Error: Failed to connect to bank: {str(e)}") click.echo(f"Error: Failed to connect to bank: {str(e)}")

View File

@@ -12,10 +12,12 @@ def status(ctx: click.Context):
List all connected banks and their status List all connected banks and their status
""" """
api_client = LeggendAPIClient(ctx.obj.get("api_url")) api_client = LeggendAPIClient(ctx.obj.get("api_url"))
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
# Get bank connection status # Get bank connection status

View File

@@ -6,15 +6,15 @@ from leggen.utils.text import error, info, success
@cli.command() @cli.command()
@click.option('--wait', is_flag=True, help='Wait for sync to complete (synchronous)') @click.option("--wait", is_flag=True, help="Wait for sync to complete (synchronous)")
@click.option('--force', is_flag=True, help='Force sync even if already running') @click.option("--force", is_flag=True, help="Force sync even if already running")
@click.pass_context @click.pass_context
def sync(ctx: click.Context, wait: bool, force: bool): def sync(ctx: click.Context, wait: bool, force: bool):
""" """
Sync all transactions with database Sync all transactions with database
""" """
api_client = LeggendAPIClient(ctx.obj.get("api_url")) api_client = LeggendAPIClient(ctx.obj.get("api_url"))
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
error("Cannot connect to leggend service. Please ensure it's running.") error("Cannot connect to leggend service. Please ensure it's running.")
@@ -25,35 +25,37 @@ def sync(ctx: click.Context, wait: bool, force: bool):
# Run sync synchronously and wait for completion # Run sync synchronously and wait for completion
info("Starting synchronous sync...") info("Starting synchronous sync...")
result = api_client.sync_now(force=force) result = api_client.sync_now(force=force)
if result.get("success"): if result.get("success"):
success(f"Sync completed successfully!") success(f"Sync completed successfully!")
info(f"Accounts processed: {result.get('accounts_processed', 0)}") info(f"Accounts processed: {result.get('accounts_processed', 0)}")
info(f"Transactions added: {result.get('transactions_added', 0)}") info(f"Transactions added: {result.get('transactions_added', 0)}")
info(f"Balances updated: {result.get('balances_updated', 0)}") info(f"Balances updated: {result.get('balances_updated', 0)}")
if result.get('duration_seconds'): if result.get("duration_seconds"):
info(f"Duration: {result['duration_seconds']:.2f} seconds") info(f"Duration: {result['duration_seconds']:.2f} seconds")
if result.get('errors'): if result.get("errors"):
error(f"Errors encountered: {len(result['errors'])}") error(f"Errors encountered: {len(result['errors'])}")
for err in result['errors']: for err in result["errors"]:
error(f" - {err}") error(f" - {err}")
else: else:
error("Sync failed") error("Sync failed")
if result.get('errors'): if result.get("errors"):
for err in result['errors']: for err in result["errors"]:
error(f" - {err}") error(f" - {err}")
else: else:
# Trigger async sync # Trigger async sync
info("Starting background sync...") info("Starting background sync...")
result = api_client.trigger_sync(force=force) result = api_client.trigger_sync(force=force)
if result.get("sync_started"): if result.get("sync_started"):
success("Sync started successfully in the background") success("Sync started successfully in the background")
info("Use 'leggen sync --wait' to run synchronously or check status with API") info(
"Use 'leggen sync --wait' to run synchronously or check status with API"
)
else: else:
error("Failed to start sync") error("Failed to start sync")
except Exception as e: except Exception as e:
error(f"Sync failed: {str(e)}") error(f"Sync failed: {str(e)}")
return return

View File

@@ -7,7 +7,9 @@ from leggen.utils.text import datefmt, info, print_table
@cli.command() @cli.command()
@click.option("-a", "--account", type=str, help="Account ID") @click.option("-a", "--account", type=str, help="Account ID")
@click.option("-l", "--limit", type=int, default=50, help="Number of transactions to show") @click.option(
"-l", "--limit", type=int, default=50, help="Number of transactions to show"
)
@click.option("--full", is_flag=True, help="Show full transaction details") @click.option("--full", is_flag=True, help="Show full transaction details")
@click.pass_context @click.pass_context
def transactions(ctx: click.Context, account: str, limit: int, full: bool): def transactions(ctx: click.Context, account: str, limit: int, full: bool):
@@ -19,10 +21,12 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
If the --account option is used, it will only list transactions for that account. If the --account option is used, it will only list transactions for that account.
""" """
api_client = LeggendAPIClient(ctx.obj.get("api_url")) api_client = LeggendAPIClient(ctx.obj.get("api_url"))
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
try: try:
@@ -32,16 +36,14 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
transactions_data = api_client.get_account_transactions( transactions_data = api_client.get_account_transactions(
account, limit=limit, summary_only=not full account, limit=limit, summary_only=not full
) )
info(f"Bank: {account_details['institution_id']}") info(f"Bank: {account_details['institution_id']}")
info(f"IBAN: {account_details.get('iban', 'N/A')}") info(f"IBAN: {account_details.get('iban', 'N/A')}")
else: else:
# Get all transactions # Get all transactions
transactions_data = api_client.get_all_transactions( transactions_data = api_client.get_all_transactions(
limit=limit, limit=limit, summary_only=not full, account_id=account
summary_only=not full,
account_id=account
) )
# Format transactions for display # Format transactions for display
@@ -49,24 +51,32 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
# Full transaction details # Full transaction details
formatted_transactions = [] formatted_transactions = []
for txn in transactions_data: for txn in transactions_data:
formatted_transactions.append({ formatted_transactions.append(
"ID": txn["internal_transaction_id"][:12] + "...", {
"Date": datefmt(txn["transaction_date"]), "ID": txn["internal_transaction_id"][:12] + "...",
"Description": txn["description"][:50] + "..." if len(txn["description"]) > 50 else txn["description"], "Date": datefmt(txn["transaction_date"]),
"Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}", "Description": txn["description"][:50] + "..."
"Status": txn["transaction_status"].upper(), if len(txn["description"]) > 50
"Account": txn["account_id"][:8] + "...", else txn["description"],
}) "Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}",
"Status": txn["transaction_status"].upper(),
"Account": txn["account_id"][:8] + "...",
}
)
else: else:
# Summary view # Summary view
formatted_transactions = [] formatted_transactions = []
for txn in transactions_data: for txn in transactions_data:
formatted_transactions.append({ formatted_transactions.append(
"Date": datefmt(txn["date"]), {
"Description": txn["description"][:60] + "..." if len(txn["description"]) > 60 else txn["description"], "Date": datefmt(txn["date"]),
"Amount": f"{txn['amount']:.2f} {txn['currency']}", "Description": txn["description"][:60] + "..."
"Status": txn["status"].upper(), if len(txn["description"]) > 60
}) else txn["description"],
"Amount": f"{txn['amount']:.2f} {txn['currency']}",
"Status": txn["status"].upper(),
}
)
if formatted_transactions: if formatted_transactions:
print_table(formatted_transactions) print_table(formatted_transactions)

View File

@@ -90,10 +90,10 @@ class Group(click.Group):
@click.option( @click.option(
"--api-url", "--api-url",
type=str, type=str,
default=None, default="http://localhost:8000",
envvar="LEGGEND_API_URL", envvar="LEGGEND_API_URL",
show_envvar=True, show_envvar=True,
help="URL of the leggend API service (default: http://localhost:8000)", help="URL of the leggend API service",
) )
@click.group( @click.group(
cls=Group, cls=Group,
@@ -113,7 +113,7 @@ def cli(ctx: click.Context, api_url: str):
# Store API URL in context for commands to use # Store API URL in context for commands to use
if api_url: if api_url:
ctx.obj["api_url"] = api_url ctx.obj["api_url"] = api_url
# For backwards compatibility, still support direct GoCardless calls # For backwards compatibility, still support direct GoCardless calls
# This will be used as fallback if leggend service is not available # This will be used as fallback if leggend service is not available
try: try:

View File

View File

@@ -6,19 +6,19 @@ from pydantic import BaseModel
class AccountBalance(BaseModel): class AccountBalance(BaseModel):
"""Account balance model""" """Account balance model"""
amount: float amount: float
currency: str currency: str
balance_type: str balance_type: str
last_change_date: Optional[datetime] = None last_change_date: Optional[datetime] = None
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class AccountDetails(BaseModel): class AccountDetails(BaseModel):
"""Account details model""" """Account details model"""
id: str id: str
institution_id: str institution_id: str
status: str status: str
@@ -28,15 +28,14 @@ class AccountDetails(BaseModel):
created: datetime created: datetime
last_accessed: Optional[datetime] = None last_accessed: Optional[datetime] = None
balances: List[AccountBalance] = [] balances: List[AccountBalance] = []
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class Transaction(BaseModel): class Transaction(BaseModel):
"""Transaction model""" """Transaction model"""
internal_transaction_id: str internal_transaction_id: str
institution_id: str institution_id: str
iban: Optional[str] = None iban: Optional[str] = None
@@ -47,15 +46,14 @@ class Transaction(BaseModel):
transaction_currency: str transaction_currency: str
transaction_status: str # "booked" or "pending" transaction_status: str # "booked" or "pending"
raw_transaction: Dict[str, Any] raw_transaction: Dict[str, Any]
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class TransactionSummary(BaseModel): class TransactionSummary(BaseModel):
"""Transaction summary for lists""" """Transaction summary for lists"""
internal_transaction_id: str internal_transaction_id: str
date: datetime date: datetime
description: str description: str
@@ -63,8 +61,6 @@ class TransactionSummary(BaseModel):
currency: str currency: str
status: str status: str
account_id: str account_id: str
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
class BankInstitution(BaseModel): class BankInstitution(BaseModel):
"""Bank institution model""" """Bank institution model"""
id: str id: str
name: str name: str
bic: Optional[str] = None bic: Optional[str] = None
@@ -16,12 +17,14 @@ class BankInstitution(BaseModel):
class BankConnectionRequest(BaseModel): class BankConnectionRequest(BaseModel):
"""Request to connect to a bank""" """Request to connect to a bank"""
institution_id: str institution_id: str
redirect_url: Optional[str] = "http://localhost:8000/" redirect_url: Optional[str] = "http://localhost:8000/"
class BankRequisition(BaseModel): class BankRequisition(BaseModel):
"""Bank requisition/connection model""" """Bank requisition/connection model"""
id: str id: str
institution_id: str institution_id: str
status: str status: str
@@ -29,15 +32,14 @@ class BankRequisition(BaseModel):
created: datetime created: datetime
link: str link: str
accounts: List[str] = [] accounts: List[str] = []
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class BankConnectionStatus(BaseModel): class BankConnectionStatus(BaseModel):
"""Bank connection status response""" """Bank connection status response"""
bank_id: str bank_id: str
bank_name: str bank_name: str
status: str status: str
@@ -45,8 +47,6 @@ class BankConnectionStatus(BaseModel):
created_at: datetime created_at: datetime
requisition_id: str requisition_id: str
accounts_count: int accounts_count: int
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
class APIResponse(BaseModel): class APIResponse(BaseModel):
"""Base API response model""" """Base API response model"""
success: bool = True success: bool = True
message: Optional[str] = None message: Optional[str] = None
data: Optional[Any] = None data: Optional[Any] = None
@@ -13,6 +14,7 @@ class APIResponse(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Error response model""" """Error response model"""
success: bool = False success: bool = False
message: str message: str
error_code: Optional[str] = None error_code: Optional[str] = None
@@ -21,7 +23,8 @@ class ErrorResponse(BaseModel):
class PaginatedResponse(BaseModel): class PaginatedResponse(BaseModel):
"""Paginated response model""" """Paginated response model"""
success: bool = True success: bool = True
data: list data: list
pagination: Dict[str, Any] pagination: Dict[str, Any]
message: Optional[str] = None message: Optional[str] = None

View File

@@ -5,12 +5,14 @@ from pydantic import BaseModel
class DiscordConfig(BaseModel): class DiscordConfig(BaseModel):
"""Discord notification configuration""" """Discord notification configuration"""
webhook: str webhook: str
enabled: bool = True enabled: bool = True
class TelegramConfig(BaseModel): class TelegramConfig(BaseModel):
"""Telegram notification configuration""" """Telegram notification configuration"""
token: str token: str
chat_id: int chat_id: int
enabled: bool = True enabled: bool = True
@@ -18,6 +20,7 @@ class TelegramConfig(BaseModel):
class NotificationFilters(BaseModel): class NotificationFilters(BaseModel):
"""Notification filters configuration""" """Notification filters configuration"""
case_insensitive: Dict[str, str] = {} case_insensitive: Dict[str, str] = {}
case_sensitive: Optional[Dict[str, str]] = None case_sensitive: Optional[Dict[str, str]] = None
amount_threshold: Optional[float] = None amount_threshold: Optional[float] = None
@@ -26,6 +29,7 @@ class NotificationFilters(BaseModel):
class NotificationSettings(BaseModel): class NotificationSettings(BaseModel):
"""Complete notification settings""" """Complete notification settings"""
discord: Optional[DiscordConfig] = None discord: Optional[DiscordConfig] = None
telegram: Optional[TelegramConfig] = None telegram: Optional[TelegramConfig] = None
filters: NotificationFilters = NotificationFilters() filters: NotificationFilters = NotificationFilters()
@@ -33,15 +37,17 @@ class NotificationSettings(BaseModel):
class NotificationTest(BaseModel): class NotificationTest(BaseModel):
"""Test notification request""" """Test notification request"""
service: str # "discord" or "telegram" service: str # "discord" or "telegram"
message: str = "Test notification from Leggen" message: str = "Test notification from Leggen"
class NotificationHistory(BaseModel): class NotificationHistory(BaseModel):
"""Notification history entry""" """Notification history entry"""
id: str id: str
service: str service: str
message: str message: str
status: str # "sent", "failed" status: str # "sent", "failed"
sent_at: str sent_at: str
error: Optional[str] = None error: Optional[str] = None

View File

@@ -6,12 +6,14 @@ from pydantic import BaseModel
class SyncRequest(BaseModel): class SyncRequest(BaseModel):
"""Request to trigger a sync""" """Request to trigger a sync"""
account_ids: Optional[list[str]] = None # If None, sync all accounts account_ids: Optional[list[str]] = None # If None, sync all accounts
force: bool = False # Force sync even if recently synced force: bool = False # Force sync even if recently synced
class SyncStatus(BaseModel): class SyncStatus(BaseModel):
"""Sync operation status""" """Sync operation status"""
is_running: bool is_running: bool
last_sync: Optional[datetime] = None last_sync: Optional[datetime] = None
next_sync: Optional[datetime] = None next_sync: Optional[datetime] = None
@@ -19,15 +21,14 @@ class SyncStatus(BaseModel):
total_accounts: int = 0 total_accounts: int = 0
transactions_added: int = 0 transactions_added: int = 0
errors: list[str] = [] errors: list[str] = []
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class SyncResult(BaseModel): class SyncResult(BaseModel):
"""Result of a sync operation""" """Result of a sync operation"""
success: bool success: bool
accounts_processed: int accounts_processed: int
transactions_added: int transactions_added: int
@@ -37,19 +38,18 @@ class SyncResult(BaseModel):
errors: list[str] = [] errors: list[str] = []
started_at: datetime started_at: datetime
completed_at: datetime completed_at: datetime
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class SchedulerConfig(BaseModel): class SchedulerConfig(BaseModel):
"""Scheduler configuration model""" """Scheduler configuration model"""
enabled: bool = True enabled: bool = True
hour: Optional[int] = 3 hour: Optional[int] = 3
minute: Optional[int] = 0 minute: Optional[int] = 0
cron: Optional[str] = None # Custom cron expression cron: Optional[str] = None # Custom cron expression
class Config: class Config:
extra = "forbid" extra = "forbid"

View File

@@ -3,7 +3,12 @@ from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggend.api.models.common import APIResponse from leggend.api.models.common import APIResponse
from leggend.api.models.accounts import AccountDetails, AccountBalance, Transaction, TransactionSummary from leggend.api.models.accounts import (
AccountDetails,
AccountBalance,
Transaction,
TransactionSummary,
)
from leggend.services.gocardless_service import GoCardlessService from leggend.services.gocardless_service import GoCardlessService
from leggend.services.database_service import DatabaseService from leggend.services.database_service import DatabaseService
@@ -17,50 +22,56 @@ async def get_all_accounts() -> APIResponse:
"""Get all connected accounts""" """Get all connected accounts"""
try: try:
requisitions_data = await gocardless_service.get_requisitions() requisitions_data = await gocardless_service.get_requisitions()
all_accounts = set() all_accounts = set()
for req in requisitions_data.get("results", []): for req in requisitions_data.get("results", []):
all_accounts.update(req.get("accounts", [])) all_accounts.update(req.get("accounts", []))
accounts = [] accounts = []
for account_id in all_accounts: for account_id in all_accounts:
try: try:
account_details = await gocardless_service.get_account_details(account_id) account_details = await gocardless_service.get_account_details(
balances_data = await gocardless_service.get_account_balances(account_id) account_id
)
balances_data = await gocardless_service.get_account_balances(
account_id
)
# Process balances # Process balances
balances = [] balances = []
for balance in balances_data.get("balances", []): for balance in balances_data.get("balances", []):
balance_amount = balance["balanceAmount"] balance_amount = balance["balanceAmount"]
balances.append(AccountBalance( balances.append(
amount=float(balance_amount["amount"]), AccountBalance(
currency=balance_amount["currency"], amount=float(balance_amount["amount"]),
balance_type=balance["balanceType"], currency=balance_amount["currency"],
last_change_date=balance.get("lastChangeDateTime") balance_type=balance["balanceType"],
)) last_change_date=balance.get("lastChangeDateTime"),
)
accounts.append(AccountDetails( )
id=account_details["id"],
institution_id=account_details["institution_id"], accounts.append(
status=account_details["status"], AccountDetails(
iban=account_details.get("iban"), id=account_details["id"],
name=account_details.get("name"), institution_id=account_details["institution_id"],
currency=account_details.get("currency"), status=account_details["status"],
created=account_details["created"], iban=account_details.get("iban"),
last_accessed=account_details.get("last_accessed"), name=account_details.get("name"),
balances=balances currency=account_details.get("currency"),
)) created=account_details["created"],
last_accessed=account_details.get("last_accessed"),
balances=balances,
)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to get details for account {account_id}: {e}") logger.error(f"Failed to get details for account {account_id}: {e}")
continue continue
return APIResponse( return APIResponse(
success=True, success=True, data=accounts, message=f"Retrieved {len(accounts)} accounts"
data=accounts,
message=f"Retrieved {len(accounts)} accounts"
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get accounts: {e}") logger.error(f"Failed to get accounts: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get accounts: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to get accounts: {str(e)}")
@@ -72,18 +83,20 @@ async def get_account_details(account_id: str) -> APIResponse:
try: try:
account_details = await gocardless_service.get_account_details(account_id) account_details = await gocardless_service.get_account_details(account_id)
balances_data = await gocardless_service.get_account_balances(account_id) balances_data = await gocardless_service.get_account_balances(account_id)
# Process balances # Process balances
balances = [] balances = []
for balance in balances_data.get("balances", []): for balance in balances_data.get("balances", []):
balance_amount = balance["balanceAmount"] balance_amount = balance["balanceAmount"]
balances.append(AccountBalance( balances.append(
amount=float(balance_amount["amount"]), AccountBalance(
currency=balance_amount["currency"], amount=float(balance_amount["amount"]),
balance_type=balance["balanceType"], currency=balance_amount["currency"],
last_change_date=balance.get("lastChangeDateTime") balance_type=balance["balanceType"],
)) last_change_date=balance.get("lastChangeDateTime"),
)
)
account = AccountDetails( account = AccountDetails(
id=account_details["id"], id=account_details["id"],
institution_id=account_details["institution_id"], institution_id=account_details["institution_id"],
@@ -93,15 +106,15 @@ async def get_account_details(account_id: str) -> APIResponse:
currency=account_details.get("currency"), currency=account_details.get("currency"),
created=account_details["created"], created=account_details["created"],
last_accessed=account_details.get("last_accessed"), last_accessed=account_details.get("last_accessed"),
balances=balances balances=balances,
) )
return APIResponse( return APIResponse(
success=True, success=True,
data=account, data=account,
message=f"Account details retrieved for {account_id}" message=f"Account details retrieved for {account_id}",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get account details for {account_id}: {e}") logger.error(f"Failed to get account details for {account_id}: {e}")
raise HTTPException(status_code=404, detail=f"Account not found: {str(e)}") raise HTTPException(status_code=404, detail=f"Account not found: {str(e)}")
@@ -112,23 +125,25 @@ async def get_account_balances(account_id: str) -> APIResponse:
"""Get balances for a specific account""" """Get balances for a specific account"""
try: try:
balances_data = await gocardless_service.get_account_balances(account_id) balances_data = await gocardless_service.get_account_balances(account_id)
balances = [] balances = []
for balance in balances_data.get("balances", []): for balance in balances_data.get("balances", []):
balance_amount = balance["balanceAmount"] balance_amount = balance["balanceAmount"]
balances.append(AccountBalance( balances.append(
amount=float(balance_amount["amount"]), AccountBalance(
currency=balance_amount["currency"], amount=float(balance_amount["amount"]),
balance_type=balance["balanceType"], currency=balance_amount["currency"],
last_change_date=balance.get("lastChangeDateTime") balance_type=balance["balanceType"],
)) last_change_date=balance.get("lastChangeDateTime"),
)
)
return APIResponse( return APIResponse(
success=True, success=True,
data=balances, data=balances,
message=f"Retrieved {len(balances)} balances for account {account_id}" message=f"Retrieved {len(balances)} balances for account {account_id}",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get balances for account {account_id}: {e}") logger.error(f"Failed to get balances for account {account_id}: {e}")
raise HTTPException(status_code=404, detail=f"Failed to get balances: {str(e)}") raise HTTPException(status_code=404, detail=f"Failed to get balances: {str(e)}")
@@ -139,22 +154,26 @@ async def get_account_transactions(
account_id: str, account_id: str,
limit: Optional[int] = Query(default=100, le=500), limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0), offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(default=False, description="Return transaction summaries only") summary_only: bool = Query(
default=False, description="Return transaction summaries only"
),
) -> APIResponse: ) -> APIResponse:
"""Get transactions for a specific account""" """Get transactions for a specific account"""
try: try:
account_details = await gocardless_service.get_account_details(account_id) account_details = await gocardless_service.get_account_details(account_id)
transactions_data = await gocardless_service.get_account_transactions(account_id) transactions_data = await gocardless_service.get_account_transactions(
account_id
)
# Process transactions # Process transactions
processed_transactions = database_service.process_transactions( processed_transactions = database_service.process_transactions(
account_id, account_details, transactions_data account_id, account_details, transactions_data
) )
# Apply pagination # Apply pagination
total_transactions = len(processed_transactions) total_transactions = len(processed_transactions)
paginated_transactions = processed_transactions[offset:offset + limit] paginated_transactions = processed_transactions[offset : offset + limit]
if summary_only: if summary_only:
# Return simplified transaction summaries # Return simplified transaction summaries
summaries = [ summaries = [
@@ -165,7 +184,7 @@ async def get_account_transactions(
amount=txn["transactionValue"], amount=txn["transactionValue"],
currency=txn["transactionCurrency"], currency=txn["transactionCurrency"],
status=txn["transactionStatus"], status=txn["transactionStatus"],
account_id=txn["accountId"] account_id=txn["accountId"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
@@ -183,18 +202,20 @@ async def get_account_transactions(
transaction_value=txn["transactionValue"], transaction_value=txn["transactionValue"],
transaction_currency=txn["transactionCurrency"], transaction_currency=txn["transactionCurrency"],
transaction_status=txn["transactionStatus"], transaction_status=txn["transactionStatus"],
raw_transaction=txn["rawTransaction"] raw_transaction=txn["rawTransaction"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
data = transactions data = transactions
return APIResponse( return APIResponse(
success=True, success=True,
data=data, data=data,
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})" message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get transactions for account {account_id}: {e}") logger.error(f"Failed to get transactions for account {account_id}: {e}")
raise HTTPException(status_code=404, detail=f"Failed to get transactions: {str(e)}") raise HTTPException(
status_code=404, detail=f"Failed to get transactions: {str(e)}"
)

View File

@@ -4,10 +4,10 @@ from loguru import logger
from leggend.api.models.common import APIResponse, ErrorResponse from leggend.api.models.common import APIResponse, ErrorResponse
from leggend.api.models.banks import ( from leggend.api.models.banks import (
BankInstitution, BankInstitution,
BankConnectionRequest, BankConnectionRequest,
BankRequisition, BankRequisition,
BankConnectionStatus BankConnectionStatus,
) )
from leggend.services.gocardless_service import GoCardlessService from leggend.services.gocardless_service import GoCardlessService
from leggend.utils.gocardless import REQUISITION_STATUS from leggend.utils.gocardless import REQUISITION_STATUS
@@ -18,12 +18,12 @@ gocardless_service = GoCardlessService()
@router.get("/banks/institutions", response_model=APIResponse) @router.get("/banks/institutions", response_model=APIResponse)
async def get_bank_institutions( async def get_bank_institutions(
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)") country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)"),
) -> APIResponse: ) -> APIResponse:
"""Get available bank institutions for a country""" """Get available bank institutions for a country"""
try: try:
institutions_data = await gocardless_service.get_institutions(country) institutions_data = await gocardless_service.get_institutions(country)
institutions = [ institutions = [
BankInstitution( BankInstitution(
id=inst["id"], id=inst["id"],
@@ -31,20 +31,22 @@ async def get_bank_institutions(
bic=inst.get("bic"), bic=inst.get("bic"),
transaction_total_days=inst["transaction_total_days"], transaction_total_days=inst["transaction_total_days"],
countries=inst["countries"], countries=inst["countries"],
logo=inst.get("logo") logo=inst.get("logo"),
) )
for inst in institutions_data for inst in institutions_data
] ]
return APIResponse( return APIResponse(
success=True, success=True,
data=institutions, data=institutions,
message=f"Found {len(institutions)} institutions for {country}" message=f"Found {len(institutions)} institutions for {country}",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get institutions for {country}: {e}") logger.error(f"Failed to get institutions for {country}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get institutions: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get institutions: {str(e)}"
)
@router.post("/banks/connect", response_model=APIResponse) @router.post("/banks/connect", response_model=APIResponse)
@@ -52,28 +54,29 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
"""Create a connection to a bank (requisition)""" """Create a connection to a bank (requisition)"""
try: try:
requisition_data = await gocardless_service.create_requisition( requisition_data = await gocardless_service.create_requisition(
request.institution_id, request.institution_id, request.redirect_url
request.redirect_url
) )
requisition = BankRequisition( requisition = BankRequisition(
id=requisition_data["id"], id=requisition_data["id"],
institution_id=requisition_data["institution_id"], institution_id=requisition_data["institution_id"],
status=requisition_data["status"], status=requisition_data["status"],
created=requisition_data["created"], created=requisition_data["created"],
link=requisition_data["link"], link=requisition_data["link"],
accounts=requisition_data.get("accounts", []) accounts=requisition_data.get("accounts", []),
) )
return APIResponse( return APIResponse(
success=True, success=True,
data=requisition, data=requisition,
message=f"Bank connection created. Please visit the link to authorize." message=f"Bank connection created. Please visit the link to authorize.",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to bank {request.institution_id}: {e}") logger.error(f"Failed to connect to bank {request.institution_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to connect to bank: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to connect to bank: {str(e)}"
)
@router.get("/banks/status", response_model=APIResponse) @router.get("/banks/status", response_model=APIResponse)
@@ -81,31 +84,37 @@ async def get_bank_connections_status() -> APIResponse:
"""Get status of all bank connections""" """Get status of all bank connections"""
try: try:
requisitions_data = await gocardless_service.get_requisitions() requisitions_data = await gocardless_service.get_requisitions()
connections = [] connections = []
for req in requisitions_data.get("results", []): for req in requisitions_data.get("results", []):
status = req["status"] status = req["status"]
status_display = REQUISITION_STATUS.get(status, "UNKNOWN") status_display = REQUISITION_STATUS.get(status, "UNKNOWN")
connections.append(BankConnectionStatus( connections.append(
bank_id=req["institution_id"], BankConnectionStatus(
bank_name=req["institution_id"], # Could be enhanced with actual bank names bank_id=req["institution_id"],
status=status, bank_name=req[
status_display=status_display, "institution_id"
created_at=req["created"], ], # Could be enhanced with actual bank names
requisition_id=req["id"], status=status,
accounts_count=len(req.get("accounts", [])) status_display=status_display,
)) created_at=req["created"],
requisition_id=req["id"],
accounts_count=len(req.get("accounts", [])),
)
)
return APIResponse( return APIResponse(
success=True, success=True,
data=connections, data=connections,
message=f"Found {len(connections)} bank connections" message=f"Found {len(connections)} bank connections",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get bank connection status: {e}") logger.error(f"Failed to get bank connection status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get bank status: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get bank status: {str(e)}"
)
@router.delete("/banks/connections/{requisition_id}", response_model=APIResponse) @router.delete("/banks/connections/{requisition_id}", response_model=APIResponse)
@@ -116,12 +125,14 @@ async def delete_bank_connection(requisition_id: str) -> APIResponse:
# For now, return success # For now, return success
return APIResponse( return APIResponse(
success=True, success=True,
message=f"Bank connection {requisition_id} deleted successfully" message=f"Bank connection {requisition_id} deleted successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to delete bank connection {requisition_id}: {e}") logger.error(f"Failed to delete bank connection {requisition_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete connection: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to delete connection: {str(e)}"
)
@router.get("/banks/countries", response_model=APIResponse) @router.get("/banks/countries", response_model=APIResponse)
@@ -160,9 +171,9 @@ async def get_supported_countries() -> APIResponse:
{"code": "SE", "name": "Sweden"}, {"code": "SE", "name": "Sweden"},
{"code": "GB", "name": "United Kingdom"}, {"code": "GB", "name": "United Kingdom"},
] ]
return APIResponse( return APIResponse(
success=True, success=True,
data=countries, data=countries,
message="Supported countries retrieved successfully" message="Supported countries retrieved successfully",
) )

View File

@@ -4,11 +4,11 @@ from loguru import logger
from leggend.api.models.common import APIResponse from leggend.api.models.common import APIResponse
from leggend.api.models.notifications import ( from leggend.api.models.notifications import (
NotificationSettings, NotificationSettings,
NotificationTest, NotificationTest,
DiscordConfig, DiscordConfig,
TelegramConfig, TelegramConfig,
NotificationFilters NotificationFilters,
) )
from leggend.services.notification_service import NotificationService from leggend.services.notification_service import NotificationService
from leggend.config import config from leggend.config import config
@@ -23,38 +23,44 @@ async def get_notification_settings() -> APIResponse:
try: try:
notifications_config = config.notifications_config notifications_config = config.notifications_config
filters_config = config.filters_config filters_config = config.filters_config
# Build response safely without exposing secrets # Build response safely without exposing secrets
discord_config = notifications_config.get("discord", {}) discord_config = notifications_config.get("discord", {})
telegram_config = notifications_config.get("telegram", {}) telegram_config = notifications_config.get("telegram", {})
settings = NotificationSettings( settings = NotificationSettings(
discord=DiscordConfig( discord=DiscordConfig(
webhook="***" if discord_config.get("webhook") else "", webhook="***" if discord_config.get("webhook") else "",
enabled=discord_config.get("enabled", True) enabled=discord_config.get("enabled", True),
) if discord_config.get("webhook") else None, )
if discord_config.get("webhook")
else None,
telegram=TelegramConfig( telegram=TelegramConfig(
token="***" if telegram_config.get("token") else "", token="***" if telegram_config.get("token") else "",
chat_id=telegram_config.get("chat_id", 0), chat_id=telegram_config.get("chat_id", 0),
enabled=telegram_config.get("enabled", True) enabled=telegram_config.get("enabled", True),
) if telegram_config.get("token") else None, )
if telegram_config.get("token")
else None,
filters=NotificationFilters( filters=NotificationFilters(
case_insensitive=filters_config.get("case-insensitive", {}), case_insensitive=filters_config.get("case-insensitive", {}),
case_sensitive=filters_config.get("case-sensitive"), case_sensitive=filters_config.get("case-sensitive"),
amount_threshold=filters_config.get("amount_threshold"), amount_threshold=filters_config.get("amount_threshold"),
keywords=filters_config.get("keywords", []) keywords=filters_config.get("keywords", []),
) ),
) )
return APIResponse( return APIResponse(
success=True, success=True,
data=settings, data=settings,
message="Notification settings retrieved successfully" message="Notification settings retrieved successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get notification settings: {e}") logger.error(f"Failed to get notification settings: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get notification settings: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get notification settings: {str(e)}"
)
@router.put("/notifications/settings", response_model=APIResponse) @router.put("/notifications/settings", response_model=APIResponse)
@@ -63,20 +69,20 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
try: try:
# Update notifications config # Update notifications config
notifications_config = {} notifications_config = {}
if settings.discord: if settings.discord:
notifications_config["discord"] = { notifications_config["discord"] = {
"webhook": settings.discord.webhook, "webhook": settings.discord.webhook,
"enabled": settings.discord.enabled "enabled": settings.discord.enabled,
} }
if settings.telegram: if settings.telegram:
notifications_config["telegram"] = { notifications_config["telegram"] = {
"token": settings.telegram.token, "token": settings.telegram.token,
"chat_id": settings.telegram.chat_id, "chat_id": settings.telegram.chat_id,
"enabled": settings.telegram.enabled "enabled": settings.telegram.enabled,
} }
# Update filters config # Update filters config
filters_config = {} filters_config = {}
if settings.filters.case_insensitive: if settings.filters.case_insensitive:
@@ -87,22 +93,24 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
filters_config["amount_threshold"] = settings.filters.amount_threshold filters_config["amount_threshold"] = settings.filters.amount_threshold
if settings.filters.keywords: if settings.filters.keywords:
filters_config["keywords"] = settings.filters.keywords filters_config["keywords"] = settings.filters.keywords
# Save to config # Save to config
if notifications_config: if notifications_config:
config.update_section("notifications", notifications_config) config.update_section("notifications", notifications_config)
if filters_config: if filters_config:
config.update_section("filters", filters_config) config.update_section("filters", filters_config)
return APIResponse( return APIResponse(
success=True, success=True,
data={"updated": True}, data={"updated": True},
message="Notification settings updated successfully" message="Notification settings updated successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to update notification settings: {e}") logger.error(f"Failed to update notification settings: {e}")
raise HTTPException(status_code=500, detail=f"Failed to update notification settings: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to update notification settings: {str(e)}"
)
@router.post("/notifications/test", response_model=APIResponse) @router.post("/notifications/test", response_model=APIResponse)
@@ -110,25 +118,26 @@ async def test_notification(test_request: NotificationTest) -> APIResponse:
"""Send a test notification""" """Send a test notification"""
try: try:
success = await notification_service.send_test_notification( success = await notification_service.send_test_notification(
test_request.service, test_request.service, test_request.message
test_request.message
) )
if success: if success:
return APIResponse( return APIResponse(
success=True, success=True,
data={"sent": True}, data={"sent": True},
message=f"Test notification sent to {test_request.service} successfully" message=f"Test notification sent to {test_request.service} successfully",
) )
else: else:
return APIResponse( return APIResponse(
success=False, success=False,
message=f"Failed to send test notification to {test_request.service}" message=f"Failed to send test notification to {test_request.service}",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send test notification: {e}") logger.error(f"Failed to send test notification: {e}")
raise HTTPException(status_code=500, detail=f"Failed to send test notification: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to send test notification: {str(e)}"
)
@router.get("/notifications/services", response_model=APIResponse) @router.get("/notifications/services", response_model=APIResponse)
@@ -136,37 +145,41 @@ async def get_notification_services() -> APIResponse:
"""Get available notification services and their status""" """Get available notification services and their status"""
try: try:
notifications_config = config.notifications_config notifications_config = config.notifications_config
services = { services = {
"discord": { "discord": {
"name": "Discord", "name": "Discord",
"enabled": bool(notifications_config.get("discord", {}).get("webhook")), "enabled": bool(notifications_config.get("discord", {}).get("webhook")),
"configured": bool(notifications_config.get("discord", {}).get("webhook")), "configured": bool(
"active": notifications_config.get("discord", {}).get("enabled", True) notifications_config.get("discord", {}).get("webhook")
),
"active": notifications_config.get("discord", {}).get("enabled", True),
}, },
"telegram": { "telegram": {
"name": "Telegram", "name": "Telegram",
"enabled": bool( "enabled": bool(
notifications_config.get("telegram", {}).get("token") and notifications_config.get("telegram", {}).get("token")
notifications_config.get("telegram", {}).get("chat_id") and notifications_config.get("telegram", {}).get("chat_id")
), ),
"configured": bool( "configured": bool(
notifications_config.get("telegram", {}).get("token") and notifications_config.get("telegram", {}).get("token")
notifications_config.get("telegram", {}).get("chat_id") and notifications_config.get("telegram", {}).get("chat_id")
), ),
"active": notifications_config.get("telegram", {}).get("enabled", True) "active": notifications_config.get("telegram", {}).get("enabled", True),
} },
} }
return APIResponse( return APIResponse(
success=True, success=True,
data=services, data=services,
message="Notification services status retrieved successfully" message="Notification services status retrieved successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get notification services: {e}") logger.error(f"Failed to get notification services: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get notification services: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get notification services: {str(e)}"
)
@router.delete("/notifications/settings/{service}", response_model=APIResponse) @router.delete("/notifications/settings/{service}", response_model=APIResponse)
@@ -174,19 +187,23 @@ async def delete_notification_service(service: str) -> APIResponse:
"""Delete/disable a notification service""" """Delete/disable a notification service"""
try: try:
if service not in ["discord", "telegram"]: if service not in ["discord", "telegram"]:
raise HTTPException(status_code=400, detail="Service must be 'discord' or 'telegram'") raise HTTPException(
status_code=400, detail="Service must be 'discord' or 'telegram'"
)
notifications_config = config.notifications_config.copy() notifications_config = config.notifications_config.copy()
if service in notifications_config: if service in notifications_config:
del notifications_config[service] del notifications_config[service]
config.update_section("notifications", notifications_config) config.update_section("notifications", notifications_config)
return APIResponse( return APIResponse(
success=True, success=True,
data={"deleted": service}, data={"deleted": service},
message=f"{service.capitalize()} notification service deleted successfully" message=f"{service.capitalize()} notification service deleted successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to delete notification service {service}: {e}") logger.error(f"Failed to delete notification service {service}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete notification service: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to delete notification service: {str(e)}"
)

View File

@@ -17,27 +17,26 @@ async def get_sync_status() -> APIResponse:
"""Get current sync status""" """Get current sync status"""
try: try:
status = await sync_service.get_sync_status() status = await sync_service.get_sync_status()
# Add scheduler information # Add scheduler information
next_sync_time = scheduler.get_next_sync_time() next_sync_time = scheduler.get_next_sync_time()
if next_sync_time: if next_sync_time:
status.next_sync = next_sync_time status.next_sync = next_sync_time
return APIResponse( return APIResponse(
success=True, success=True, data=status, message="Sync status retrieved successfully"
data=status,
message="Sync status retrieved successfully"
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get sync status: {e}") logger.error(f"Failed to get sync status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get sync status: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get sync status: {str(e)}"
)
@router.post("/sync", response_model=APIResponse) @router.post("/sync", response_model=APIResponse)
async def trigger_sync( async def trigger_sync(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks, sync_request: Optional[SyncRequest] = None
sync_request: Optional[SyncRequest] = None
) -> APIResponse: ) -> APIResponse:
"""Trigger a manual sync operation""" """Trigger a manual sync operation"""
try: try:
@@ -46,32 +45,37 @@ async def trigger_sync(
if status.is_running and not (sync_request and sync_request.force): if status.is_running and not (sync_request and sync_request.force):
return APIResponse( return APIResponse(
success=False, success=False,
message="Sync is already running. Use 'force: true' to override." message="Sync is already running. Use 'force: true' to override.",
) )
# Determine what to sync # Determine what to sync
if sync_request and sync_request.account_ids: if sync_request and sync_request.account_ids:
# Sync specific accounts in background # Sync specific accounts in background
background_tasks.add_task( background_tasks.add_task(
sync_service.sync_specific_accounts, sync_service.sync_specific_accounts,
sync_request.account_ids, sync_request.account_ids,
sync_request.force if sync_request else False sync_request.force if sync_request else False,
)
message = (
f"Started sync for {len(sync_request.account_ids)} specific accounts"
) )
message = f"Started sync for {len(sync_request.account_ids)} specific accounts"
else: else:
# Sync all accounts in background # Sync all accounts in background
background_tasks.add_task( background_tasks.add_task(
sync_service.sync_all_accounts, sync_service.sync_all_accounts,
sync_request.force if sync_request else False sync_request.force if sync_request else False,
) )
message = "Started sync for all accounts" message = "Started sync for all accounts"
return APIResponse( return APIResponse(
success=True, success=True,
data={"sync_started": True, "force": sync_request.force if sync_request else False}, data={
message=message "sync_started": True,
"force": sync_request.force if sync_request else False,
},
message=message,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to trigger sync: {e}") logger.error(f"Failed to trigger sync: {e}")
raise HTTPException(status_code=500, detail=f"Failed to trigger sync: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to trigger sync: {str(e)}")
@@ -83,20 +87,21 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
try: try:
if sync_request and sync_request.account_ids: if sync_request and sync_request.account_ids:
result = await sync_service.sync_specific_accounts( result = await sync_service.sync_specific_accounts(
sync_request.account_ids, sync_request.account_ids, sync_request.force
sync_request.force
) )
else: else:
result = await sync_service.sync_all_accounts( result = await sync_service.sync_all_accounts(
sync_request.force if sync_request else False sync_request.force if sync_request else False
) )
return APIResponse( return APIResponse(
success=result.success, success=result.success,
data=result, data=result,
message="Sync completed" if result.success else f"Sync failed with {len(result.errors)} errors" message="Sync completed"
if result.success
else f"Sync failed with {len(result.errors)} errors",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to run sync: {e}") logger.error(f"Failed to run sync: {e}")
raise HTTPException(status_code=500, detail=f"Failed to run sync: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to run sync: {str(e)}")
@@ -108,22 +113,28 @@ async def get_scheduler_config() -> APIResponse:
try: try:
scheduler_config = config.scheduler_config scheduler_config = config.scheduler_config
next_sync_time = scheduler.get_next_sync_time() next_sync_time = scheduler.get_next_sync_time()
response_data = { response_data = {
**scheduler_config, **scheduler_config,
"next_scheduled_sync": next_sync_time.isoformat() if next_sync_time else None, "next_scheduled_sync": next_sync_time.isoformat()
"is_running": scheduler.scheduler.running if hasattr(scheduler, 'scheduler') else False if next_sync_time
else None,
"is_running": scheduler.scheduler.running
if hasattr(scheduler, "scheduler")
else False,
} }
return APIResponse( return APIResponse(
success=True, success=True,
data=response_data, data=response_data,
message="Scheduler configuration retrieved successfully" message="Scheduler configuration retrieved successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get scheduler config: {e}") logger.error(f"Failed to get scheduler config: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get scheduler config: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get scheduler config: {str(e)}"
)
@router.put("/sync/scheduler", response_model=APIResponse) @router.put("/sync/scheduler", response_model=APIResponse)
@@ -135,26 +146,32 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo
try: try:
cron_parts = scheduler_config.cron.split() cron_parts = scheduler_config.cron.split()
if len(cron_parts) != 5: if len(cron_parts) != 5:
raise ValueError("Cron expression must have 5 parts: minute hour day month day_of_week") raise ValueError(
"Cron expression must have 5 parts: minute hour day month day_of_week"
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid cron expression: {str(e)}") raise HTTPException(
status_code=400, detail=f"Invalid cron expression: {str(e)}"
)
# Update configuration # Update configuration
schedule_data = scheduler_config.dict(exclude_none=True) schedule_data = scheduler_config.dict(exclude_none=True)
config.update_section("scheduler", {"sync": schedule_data}) config.update_section("scheduler", {"sync": schedule_data})
# Reschedule the job # Reschedule the job
scheduler.reschedule_sync(schedule_data) scheduler.reschedule_sync(schedule_data)
return APIResponse( return APIResponse(
success=True, success=True,
data=schedule_data, data=schedule_data,
message="Scheduler configuration updated successfully" message="Scheduler configuration updated successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to update scheduler config: {e}") logger.error(f"Failed to update scheduler config: {e}")
raise HTTPException(status_code=500, detail=f"Failed to update scheduler config: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to update scheduler config: {str(e)}"
)
@router.post("/sync/scheduler/start", response_model=APIResponse) @router.post("/sync/scheduler/start", response_model=APIResponse)
@@ -163,37 +180,29 @@ async def start_scheduler() -> APIResponse:
try: try:
if not scheduler.scheduler.running: if not scheduler.scheduler.running:
scheduler.start() scheduler.start()
return APIResponse( return APIResponse(success=True, message="Scheduler started successfully")
success=True,
message="Scheduler started successfully"
)
else: else:
return APIResponse( return APIResponse(success=True, message="Scheduler is already running")
success=True,
message="Scheduler is already running"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to start scheduler: {e}") logger.error(f"Failed to start scheduler: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start scheduler: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to start scheduler: {str(e)}"
)
@router.post("/sync/scheduler/stop", response_model=APIResponse) @router.post("/sync/scheduler/stop", response_model=APIResponse)
async def stop_scheduler() -> APIResponse: async def stop_scheduler() -> APIResponse:
"""Stop the background scheduler""" """Stop the background scheduler"""
try: try:
if scheduler.scheduler.running: if scheduler.scheduler.running:
scheduler.shutdown() scheduler.shutdown()
return APIResponse( return APIResponse(success=True, message="Scheduler stopped successfully")
success=True,
message="Scheduler stopped successfully"
)
else: else:
return APIResponse( return APIResponse(success=True, message="Scheduler is already stopped")
success=True,
message="Scheduler is already stopped"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to stop scheduler: {e}") logger.error(f"Failed to stop scheduler: {e}")
raise HTTPException(status_code=500, detail=f"Failed to stop scheduler: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to stop scheduler: {str(e)}"
)

View File

@@ -17,95 +17,111 @@ database_service = DatabaseService()
async def get_all_transactions( async def get_all_transactions(
limit: Optional[int] = Query(default=100, le=500), limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0), offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(default=True, description="Return transaction summaries only"), summary_only: bool = Query(
date_from: Optional[str] = Query(default=None, description="Filter from date (YYYY-MM-DD)"), default=True, description="Return transaction summaries only"
date_to: Optional[str] = Query(default=None, description="Filter to date (YYYY-MM-DD)"), ),
min_amount: Optional[float] = Query(default=None, description="Minimum transaction amount"), date_from: Optional[str] = Query(
max_amount: Optional[float] = Query(default=None, description="Maximum transaction amount"), default=None, description="Filter from date (YYYY-MM-DD)"
search: Optional[str] = Query(default=None, description="Search in transaction descriptions"), ),
account_id: Optional[str] = Query(default=None, description="Filter by account ID") date_to: Optional[str] = Query(
default=None, description="Filter to date (YYYY-MM-DD)"
),
min_amount: Optional[float] = Query(
default=None, description="Minimum transaction amount"
),
max_amount: Optional[float] = Query(
default=None, description="Maximum transaction amount"
),
search: Optional[str] = Query(
default=None, description="Search in transaction descriptions"
),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> APIResponse: ) -> APIResponse:
"""Get all transactions across all accounts with filtering options""" """Get all transactions across all accounts with filtering options"""
try: try:
# Get all requisitions and accounts # Get all requisitions and accounts
requisitions_data = await gocardless_service.get_requisitions() requisitions_data = await gocardless_service.get_requisitions()
all_accounts = set() all_accounts = set()
for req in requisitions_data.get("results", []): for req in requisitions_data.get("results", []):
all_accounts.update(req.get("accounts", [])) all_accounts.update(req.get("accounts", []))
# Filter by specific account if requested # Filter by specific account if requested
if account_id: if account_id:
if account_id not in all_accounts: if account_id not in all_accounts:
raise HTTPException(status_code=404, detail="Account not found") raise HTTPException(status_code=404, detail="Account not found")
all_accounts = {account_id} all_accounts = {account_id}
all_transactions = [] all_transactions = []
# Collect transactions from all accounts # Collect transactions from all accounts
for acc_id in all_accounts: for acc_id in all_accounts:
try: try:
account_details = await gocardless_service.get_account_details(acc_id) account_details = await gocardless_service.get_account_details(acc_id)
transactions_data = await gocardless_service.get_account_transactions(acc_id) transactions_data = await gocardless_service.get_account_transactions(
acc_id
)
processed_transactions = database_service.process_transactions( processed_transactions = database_service.process_transactions(
acc_id, account_details, transactions_data acc_id, account_details, transactions_data
) )
all_transactions.extend(processed_transactions) all_transactions.extend(processed_transactions)
except Exception as e: except Exception as e:
logger.error(f"Failed to get transactions for account {acc_id}: {e}") logger.error(f"Failed to get transactions for account {acc_id}: {e}")
continue continue
# Apply filters # Apply filters
filtered_transactions = all_transactions filtered_transactions = all_transactions
# Date range filter # Date range filter
if date_from: if date_from:
from_date = datetime.fromisoformat(date_from) from_date = datetime.fromisoformat(date_from)
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionDate"] >= from_date if txn["transactionDate"] >= from_date
] ]
if date_to: if date_to:
to_date = datetime.fromisoformat(date_to) to_date = datetime.fromisoformat(date_to)
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionDate"] <= to_date if txn["transactionDate"] <= to_date
] ]
# Amount filters # Amount filters
if min_amount is not None: if min_amount is not None:
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionValue"] >= min_amount if txn["transactionValue"] >= min_amount
] ]
if max_amount is not None: if max_amount is not None:
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionValue"] <= max_amount if txn["transactionValue"] <= max_amount
] ]
# Search filter # Search filter
if search: if search:
search_lower = search.lower() search_lower = search.lower()
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if search_lower in txn["description"].lower() if search_lower in txn["description"].lower()
] ]
# Sort by date (newest first) # Sort by date (newest first)
filtered_transactions.sort( filtered_transactions.sort(key=lambda x: x["transactionDate"], reverse=True)
key=lambda x: x["transactionDate"],
reverse=True
)
# Apply pagination # Apply pagination
total_transactions = len(filtered_transactions) total_transactions = len(filtered_transactions)
paginated_transactions = filtered_transactions[offset:offset + limit] paginated_transactions = filtered_transactions[offset : offset + limit]
if summary_only: if summary_only:
# Return simplified transaction summaries # Return simplified transaction summaries
data = [ data = [
@@ -116,7 +132,7 @@ async def get_all_transactions(
amount=txn["transactionValue"], amount=txn["transactionValue"],
currency=txn["transactionCurrency"], currency=txn["transactionCurrency"],
status=txn["transactionStatus"], status=txn["transactionStatus"],
account_id=txn["accountId"] account_id=txn["accountId"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
@@ -133,86 +149,99 @@ async def get_all_transactions(
transaction_value=txn["transactionValue"], transaction_value=txn["transactionValue"],
transaction_currency=txn["transactionCurrency"], transaction_currency=txn["transactionCurrency"],
transaction_status=txn["transactionStatus"], transaction_status=txn["transactionStatus"],
raw_transaction=txn["rawTransaction"] raw_transaction=txn["rawTransaction"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
return APIResponse( return APIResponse(
success=True, success=True,
data=data, data=data,
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})" message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get transactions: {e}") logger.error(f"Failed to get transactions: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get transactions: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get transactions: {str(e)}"
)
@router.get("/transactions/stats", response_model=APIResponse) @router.get("/transactions/stats", response_model=APIResponse)
async def get_transaction_stats( async def get_transaction_stats(
days: int = Query(default=30, description="Number of days to include in stats"), days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID") account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> APIResponse: ) -> APIResponse:
"""Get transaction statistics for the last N days""" """Get transaction statistics for the last N days"""
try: try:
# Date range for stats # Date range for stats
end_date = datetime.now() end_date = datetime.now()
start_date = end_date - timedelta(days=days) start_date = end_date - timedelta(days=days)
# Get all transactions (reuse the existing endpoint logic) # Get all transactions (reuse the existing endpoint logic)
# This is a simplified implementation - in practice you might want to optimize this # This is a simplified implementation - in practice you might want to optimize this
requisitions_data = await gocardless_service.get_requisitions() requisitions_data = await gocardless_service.get_requisitions()
all_accounts = set() all_accounts = set()
for req in requisitions_data.get("results", []): for req in requisitions_data.get("results", []):
all_accounts.update(req.get("accounts", [])) all_accounts.update(req.get("accounts", []))
if account_id: if account_id:
if account_id not in all_accounts: if account_id not in all_accounts:
raise HTTPException(status_code=404, detail="Account not found") raise HTTPException(status_code=404, detail="Account not found")
all_accounts = {account_id} all_accounts = {account_id}
all_transactions = [] all_transactions = []
for acc_id in all_accounts: for acc_id in all_accounts:
try: try:
account_details = await gocardless_service.get_account_details(acc_id) account_details = await gocardless_service.get_account_details(acc_id)
transactions_data = await gocardless_service.get_account_transactions(acc_id) transactions_data = await gocardless_service.get_account_transactions(
acc_id
)
processed_transactions = database_service.process_transactions( processed_transactions = database_service.process_transactions(
acc_id, account_details, transactions_data acc_id, account_details, transactions_data
) )
all_transactions.extend(processed_transactions) all_transactions.extend(processed_transactions)
except Exception as e: except Exception as e:
logger.error(f"Failed to get transactions for account {acc_id}: {e}") logger.error(f"Failed to get transactions for account {acc_id}: {e}")
continue continue
# Filter transactions by date range # Filter transactions by date range
recent_transactions = [ recent_transactions = [
txn for txn in all_transactions txn
for txn in all_transactions
if start_date <= txn["transactionDate"] <= end_date if start_date <= txn["transactionDate"] <= end_date
] ]
# Calculate stats # Calculate stats
total_transactions = len(recent_transactions) total_transactions = len(recent_transactions)
total_income = sum( total_income = sum(
txn["transactionValue"] txn["transactionValue"]
for txn in recent_transactions for txn in recent_transactions
if txn["transactionValue"] > 0 if txn["transactionValue"] > 0
) )
total_expenses = sum( total_expenses = sum(
abs(txn["transactionValue"]) abs(txn["transactionValue"])
for txn in recent_transactions for txn in recent_transactions
if txn["transactionValue"] < 0 if txn["transactionValue"] < 0
) )
net_change = total_income - total_expenses net_change = total_income - total_expenses
# Count by status # Count by status
booked_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]) booked_count = len(
pending_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "pending"]) [txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]
)
pending_count = len(
[
txn
for txn in recent_transactions
if txn["transactionStatus"] == "pending"
]
)
stats = { stats = {
"period_days": days, "period_days": days,
"total_transactions": total_transactions, "total_transactions": total_transactions,
@@ -222,17 +251,23 @@ async def get_transaction_stats(
"total_expenses": round(total_expenses, 2), "total_expenses": round(total_expenses, 2),
"net_change": round(net_change, 2), "net_change": round(net_change, 2),
"average_transaction": round( "average_transaction": round(
sum(txn["transactionValue"] for txn in recent_transactions) / total_transactions, 2 sum(txn["transactionValue"] for txn in recent_transactions)
) if total_transactions > 0 else 0, / total_transactions,
"accounts_included": len(all_accounts) 2,
)
if total_transactions > 0
else 0,
"accounts_included": len(all_accounts),
} }
return APIResponse( return APIResponse(
success=True, success=True,
data=stats, data=stats,
message=f"Transaction statistics for last {days} days" message=f"Transaction statistics for last {days} days",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get transaction stats: {e}") logger.error(f"Failed to get transaction stats: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get transaction stats: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get transaction stats: {str(e)}"
)

View File

@@ -4,47 +4,30 @@ from loguru import logger
from leggend.config import config from leggend.config import config
from leggend.services.sync_service import SyncService from leggend.services.sync_service import SyncService
from leggend.services.notification_service import NotificationService
class BackgroundScheduler: class BackgroundScheduler:
def __init__(self): def __init__(self):
self.scheduler = AsyncIOScheduler() self.scheduler = AsyncIOScheduler()
self.sync_service = SyncService() self.sync_service = SyncService()
self.notification_service = NotificationService()
self.max_retries = 3
self.retry_delay = 300 # 5 minutes
def start(self): def start(self):
"""Start the scheduler and configure sync jobs based on configuration""" """Start the scheduler and configure sync jobs based on configuration"""
schedule_config = config.scheduler_config.get("sync", {}) schedule_config = config.scheduler_config.get("sync", {})
if not schedule_config.get("enabled", True): if not schedule_config.get("enabled", True):
logger.info("Sync scheduling is disabled in configuration") logger.info("Sync scheduling is disabled in configuration")
self.scheduler.start() self.scheduler.start()
return return
# Use custom cron expression if provided, otherwise use hour/minute # Parse schedule configuration
if schedule_config.get("cron"): trigger = self._parse_cron_config(schedule_config)
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM) if not trigger:
try: return
cron_parts = schedule_config["cron"].split()
if len(cron_parts) == 5:
minute, hour, day, month, day_of_week = cron_parts
trigger = CronTrigger(
minute=minute,
hour=hour,
day=day if day != "*" else None,
month=month if month != "*" else None,
day_of_week=day_of_week if day_of_week != "*" else None,
)
else:
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
return
except Exception as e:
logger.error(f"Error parsing cron expression: {e}")
return
else:
# Use hour/minute configuration (default: 3:00 AM daily)
hour = schedule_config.get("hour", 3)
minute = schedule_config.get("minute", 0)
trigger = CronTrigger(hour=hour, minute=minute)
self.scheduler.add_job( self.scheduler.add_job(
self._run_sync, self._run_sync,
@@ -53,7 +36,7 @@ class BackgroundScheduler:
name="Scheduled sync of all transactions", name="Scheduled sync of all transactions",
max_instances=1, max_instances=1,
) )
self.scheduler.start() self.scheduler.start()
logger.info(f"Background scheduler started with sync job: {trigger}") logger.info(f"Background scheduler started with sync job: {trigger}")
@@ -76,28 +59,9 @@ class BackgroundScheduler:
return return
# Configure new schedule # Configure new schedule
if schedule_config.get("cron"): trigger = self._parse_cron_config(schedule_config)
try: if not trigger:
cron_parts = schedule_config["cron"].split() return
if len(cron_parts) == 5:
minute, hour, day, month, day_of_week = cron_parts
trigger = CronTrigger(
minute=minute,
hour=hour,
day=day if day != "*" else None,
month=month if month != "*" else None,
day_of_week=day_of_week if day_of_week != "*" else None,
)
else:
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
return
except Exception as e:
logger.error(f"Error parsing cron expression: {e}")
return
else:
hour = schedule_config.get("hour", 3)
minute = schedule_config.get("minute", 0)
trigger = CronTrigger(hour=hour, minute=minute)
self.scheduler.add_job( self.scheduler.add_job(
self._run_sync, self._run_sync,
@@ -108,13 +72,90 @@ class BackgroundScheduler:
) )
logger.info(f"Rescheduled sync job with: {trigger}") logger.info(f"Rescheduled sync job with: {trigger}")
async def _run_sync(self): def _parse_cron_config(self, schedule_config: dict) -> CronTrigger:
"""Parse cron configuration and return CronTrigger"""
if schedule_config.get("cron"):
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
try:
cron_parts = schedule_config["cron"].split()
if len(cron_parts) == 5:
minute, hour, day, month, day_of_week = cron_parts
return CronTrigger(
minute=minute,
hour=hour,
day=day if day != "*" else None,
month=month if month != "*" else None,
day_of_week=day_of_week if day_of_week != "*" else None,
)
else:
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
return None
except Exception as e:
logger.error(f"Error parsing cron expression: {e}")
return None
else:
# Use hour/minute configuration (default: 3:00 AM daily)
hour = schedule_config.get("hour", 3)
minute = schedule_config.get("minute", 0)
return CronTrigger(hour=hour, minute=minute)
async def _run_sync(self, retry_count: int = 0):
"""Run sync with enhanced error handling and retry logic"""
try: try:
logger.info("Starting scheduled sync job") logger.info("Starting scheduled sync job")
await self.sync_service.sync_all_accounts() await self.sync_service.sync_all_accounts()
logger.info("Scheduled sync job completed successfully") logger.info("Scheduled sync job completed successfully")
except Exception as e: except Exception as e:
logger.error(f"Scheduled sync job failed: {e}") logger.error(
f"Scheduled sync job failed (attempt {retry_count + 1}/{self.max_retries}): {e}"
)
# Send notification about the failure
try:
await self.notification_service.send_expiry_notification(
{
"type": "sync_failure",
"error": str(e),
"retry_count": retry_count + 1,
"max_retries": self.max_retries,
}
)
except Exception as notification_error:
logger.error(
f"Failed to send failure notification: {notification_error}"
)
# Implement retry logic for transient failures
if retry_count < self.max_retries - 1:
import datetime
logger.info(f"Retrying sync job in {self.retry_delay} seconds...")
# Schedule a retry
retry_time = datetime.datetime.now() + datetime.timedelta(
seconds=self.retry_delay
)
self.scheduler.add_job(
self._run_sync,
"date",
args=[retry_count + 1],
id=f"sync_retry_{retry_count + 1}",
run_date=retry_time,
)
else:
logger.error("Maximum retries exceeded for sync job")
# Send final failure notification
try:
await self.notification_service.send_expiry_notification(
{
"type": "sync_final_failure",
"error": str(e),
"retry_count": retry_count + 1,
}
)
except Exception as notification_error:
logger.error(
f"Failed to send final failure notification: {notification_error}"
)
def get_next_sync_time(self): def get_next_sync_time(self):
"""Get the next scheduled sync time""" """Get the next scheduled sync time"""
@@ -124,4 +165,4 @@ class BackgroundScheduler:
return None return None
scheduler = BackgroundScheduler() scheduler = BackgroundScheduler()

View File

@@ -24,7 +24,7 @@ class Config:
if config_path is None: if config_path is None:
config_path = os.environ.get( config_path = os.environ.get(
"LEGGEN_CONFIG_FILE", "LEGGEN_CONFIG_FILE",
str(Path.home() / ".config" / "leggen" / "config.toml") str(Path.home() / ".config" / "leggen" / "config.toml"),
) )
self._config_path = config_path self._config_path = config_path
@@ -42,15 +42,17 @@ class Config:
return self._config return self._config
def save_config(self, config_data: Dict[str, Any] = None, config_path: str = None) -> None: def save_config(
self, config_data: Dict[str, Any] = None, config_path: str = None
) -> None:
"""Save configuration to TOML file""" """Save configuration to TOML file"""
if config_data is None: if config_data is None:
config_data = self._config config_data = self._config
if config_path is None: if config_path is None:
config_path = self._config_path or os.environ.get( config_path = self._config_path or os.environ.get(
"LEGGEN_CONFIG_FILE", "LEGGEN_CONFIG_FILE",
str(Path.home() / ".config" / "leggen" / "config.toml") str(Path.home() / ".config" / "leggen" / "config.toml"),
) )
# Ensure directory exists # Ensure directory exists
@@ -59,7 +61,7 @@ class Config:
try: try:
with open(config_path, "wb") as f: with open(config_path, "wb") as f:
tomli_w.dump(config_data, f) tomli_w.dump(config_data, f)
# Update in-memory config # Update in-memory config
self._config = config_data self._config = config_data
self._config_path = config_path self._config_path = config_path
@@ -72,10 +74,10 @@ class Config:
"""Update a specific configuration value""" """Update a specific configuration value"""
if self._config is None: if self._config is None:
self.load_config() self.load_config()
if section not in self._config: if section not in self._config:
self._config[section] = {} self._config[section] = {}
self._config[section][key] = value self._config[section][key] = value
self.save_config() self.save_config()
@@ -83,7 +85,7 @@ class Config:
"""Update an entire configuration section""" """Update an entire configuration section"""
if self._config is None: if self._config is None:
self.load_config() self.load_config()
self._config[section] = data self._config[section] = data
self.save_config() self.save_config()
@@ -117,10 +119,10 @@ class Config:
"enabled": True, "enabled": True,
"hour": 3, "hour": 3,
"minute": 0, "minute": 0,
"cron": None # Optional custom cron expression "cron": None, # Optional custom cron expression
} }
} }
return self.config.get("scheduler", default_schedule) return self.config.get("scheduler", default_schedule)
config = Config() config = Config()

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from importlib import metadata
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@@ -15,7 +16,7 @@ from leggend.config import config
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup # Startup
logger.info("Starting leggend service...") logger.info("Starting leggend service...")
# Load configuration # Load configuration
try: try:
config.load_config() config.load_config()
@@ -27,26 +28,35 @@ async def lifespan(app: FastAPI):
# Start background scheduler # Start background scheduler
scheduler.start() scheduler.start()
logger.info("Background scheduler started") logger.info("Background scheduler started")
yield yield
# Shutdown # Shutdown
logger.info("Shutting down leggend service...") logger.info("Shutting down leggend service...")
scheduler.shutdown() scheduler.shutdown()
def create_app() -> FastAPI: def create_app() -> FastAPI:
# Get version dynamically from package metadata
try:
version = metadata.version("leggen")
except metadata.PackageNotFoundError:
version = "unknown"
app = FastAPI( app = FastAPI(
title="Leggend API", title="Leggend API",
description="Open Banking API for Leggen", description="Open Banking API for Leggen",
version="0.6.11", version=version,
lifespan=lifespan, lifespan=lifespan,
) )
# Add CORS middleware # Add CORS middleware
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"], # SvelteKit dev servers allow_origins=[
"http://localhost:3000",
"http://localhost:5173",
], # SvelteKit dev servers
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
@@ -60,7 +70,12 @@ def create_app() -> FastAPI:
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "Leggend API is running", "version": "0.6.11"} # Get version dynamically
try:
version = metadata.version("leggen")
except metadata.PackageNotFoundError:
version = "unknown"
return {"message": "Leggend API is running", "version": version}
@app.get("/health") @app.get("/health")
async def health(): async def health():
@@ -71,25 +86,19 @@ def create_app() -> FastAPI:
def main(): def main():
import argparse import argparse
parser = argparse.ArgumentParser(description="Start the Leggend API service") parser = argparse.ArgumentParser(description="Start the Leggend API service")
parser.add_argument( parser.add_argument(
"--reload", "--reload", action="store_true", help="Enable auto-reload for development"
action="store_true",
help="Enable auto-reload for development"
) )
parser.add_argument( parser.add_argument(
"--host", "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
default="0.0.0.0",
help="Host to bind to (default: 0.0.0.0)"
) )
parser.add_argument( parser.add_argument(
"--port", "--port", type=int, default=8000, help="Port to bind to (default: 8000)"
type=int,
default=8000,
help="Port to bind to (default: 8000)"
) )
args = parser.parse_args() args = parser.parse_args()
if args.reload: if args.reload:
# Use string import for reload to work properly # Use string import for reload to work properly
uvicorn.run( uvicorn.run(
@@ -114,4 +123,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -11,7 +11,9 @@ class DatabaseService:
self.db_config = config.database_config self.db_config = config.database_config
self.sqlite_enabled = self.db_config.get("sqlite", True) self.sqlite_enabled = self.db_config.get("sqlite", True)
async def persist_balance(self, account_id: str, balance_data: Dict[str, Any]) -> None: async def persist_balance(
self, account_id: str, balance_data: Dict[str, Any]
) -> None:
"""Persist account balance data""" """Persist account balance data"""
if not self.sqlite_enabled: if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping balance persistence") logger.warning("SQLite database disabled, skipping balance persistence")
@@ -19,7 +21,9 @@ class DatabaseService:
await self._persist_balance_sqlite(account_id, balance_data) await self._persist_balance_sqlite(account_id, balance_data)
async def persist_transactions(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: async def persist_transactions(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions and return new transactions""" """Persist transactions and return new transactions"""
if not self.sqlite_enabled: if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping transaction persistence") logger.warning("SQLite database disabled, skipping transaction persistence")
@@ -27,32 +31,48 @@ class DatabaseService:
return await self._persist_transactions_sqlite(account_id, transactions) return await self._persist_transactions_sqlite(account_id, transactions)
def process_transactions(self, account_id: str, account_info: Dict[str, Any], transaction_data: Dict[str, Any]) -> List[Dict[str, Any]]: def process_transactions(
self,
account_id: str,
account_info: Dict[str, Any],
transaction_data: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Process raw transaction data into standardized format""" """Process raw transaction data into standardized format"""
transactions = [] transactions = []
# Process booked transactions # Process booked transactions
for transaction in transaction_data.get("transactions", {}).get("booked", []): for transaction in transaction_data.get("transactions", {}).get("booked", []):
processed = self._process_single_transaction(account_id, account_info, transaction, "booked") processed = self._process_single_transaction(
account_id, account_info, transaction, "booked"
)
transactions.append(processed) transactions.append(processed)
# Process pending transactions # Process pending transactions
for transaction in transaction_data.get("transactions", {}).get("pending", []): for transaction in transaction_data.get("transactions", {}).get("pending", []):
processed = self._process_single_transaction(account_id, account_info, transaction, "pending") processed = self._process_single_transaction(
account_id, account_info, transaction, "pending"
)
transactions.append(processed) transactions.append(processed)
return transactions return transactions
def _process_single_transaction(self, account_id: str, account_info: Dict[str, Any], transaction: Dict[str, Any], status: str) -> Dict[str, Any]: def _process_single_transaction(
self,
account_id: str,
account_info: Dict[str, Any],
transaction: Dict[str, Any],
status: str,
) -> Dict[str, Any]:
"""Process a single transaction into standardized format""" """Process a single transaction into standardized format"""
# Extract dates # Extract dates
booked_date = transaction.get("bookingDateTime") or transaction.get("bookingDate") booked_date = transaction.get("bookingDateTime") or transaction.get(
"bookingDate"
)
value_date = transaction.get("valueDateTime") or transaction.get("valueDate") value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
if booked_date and value_date: if booked_date and value_date:
min_date = min( min_date = min(
datetime.fromisoformat(booked_date), datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
datetime.fromisoformat(value_date)
) )
else: else:
min_date = datetime.fromisoformat(booked_date or value_date) min_date = datetime.fromisoformat(booked_date or value_date)
@@ -65,7 +85,7 @@ class DatabaseService:
# Extract description # Extract description
description = transaction.get( description = transaction.get(
"remittanceInformationUnstructured", "remittanceInformationUnstructured",
",".join(transaction.get("remittanceInformationUnstructuredArray", [])) ",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
) )
return { return {
@@ -81,13 +101,19 @@ class DatabaseService:
"rawTransaction": transaction, "rawTransaction": transaction,
} }
async def _persist_balance_sqlite(self, account_id: str, balance_data: Dict[str, Any]) -> None: async def _persist_balance_sqlite(
self, account_id: str, balance_data: Dict[str, Any]
) -> None:
"""Persist balance to SQLite - placeholder implementation""" """Persist balance to SQLite - placeholder implementation"""
# Would import and use leggen.database.sqlite # Would import and use leggen.database.sqlite
logger.info(f"Persisting balance to SQLite for account {account_id}") logger.info(f"Persisting balance to SQLite for account {account_id}")
async def _persist_transactions_sqlite(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: async def _persist_transactions_sqlite(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions to SQLite - placeholder implementation""" """Persist transactions to SQLite - placeholder implementation"""
# Would import and use leggen.database.sqlite # Would import and use leggen.database.sqlite
logger.info(f"Persisting {len(transactions)} transactions to SQLite for account {account_id}") logger.info(
return transactions # Return new transactions for notifications f"Persisting {len(transactions)} transactions to SQLite for account {account_id}"
)
return transactions # Return new transactions for notifications

View File

@@ -12,37 +12,36 @@ from leggend.config import config
class GoCardlessService: class GoCardlessService:
def __init__(self): def __init__(self):
self.config = config.gocardless_config self.config = config.gocardless_config
self.base_url = self.config.get("url", "https://bankaccountdata.gocardless.com/api/v2") self.base_url = self.config.get(
"url", "https://bankaccountdata.gocardless.com/api/v2"
)
self._token = None self._token = None
async def _get_auth_headers(self) -> Dict[str, str]: async def _get_auth_headers(self) -> Dict[str, str]:
"""Get authentication headers for GoCardless API""" """Get authentication headers for GoCardless API"""
token = await self._get_token() token = await self._get_token()
return { return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
async def _get_token(self) -> str: async def _get_token(self) -> str:
"""Get access token for GoCardless API""" """Get access token for GoCardless API"""
if self._token: if self._token:
return self._token return self._token
# Use ~/.config/leggen for consistency with main config # Use ~/.config/leggen for consistency with main config
auth_file = Path.home() / ".config" / "leggen" / "auth.json" auth_file = Path.home() / ".config" / "leggen" / "auth.json"
if auth_file.exists(): if auth_file.exists():
try: try:
with open(auth_file, "r") as f: with open(auth_file, "r") as f:
auth = json.load(f) auth = json.load(f)
if auth.get("access"): if auth.get("access"):
# Try to refresh the token # Try to refresh the token
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
response = await client.post( response = await client.post(
f"{self.base_url}/token/refresh/", f"{self.base_url}/token/refresh/",
json={"refresh": auth["refresh"]} json={"refresh": auth["refresh"]},
) )
response.raise_for_status() response.raise_for_status()
auth.update(response.json()) auth.update(response.json())
@@ -84,7 +83,7 @@ class GoCardlessService:
"""Save authentication data to file""" """Save authentication data to file"""
auth_file = Path.home() / ".config" / "leggen" / "auth.json" auth_file = Path.home() / ".config" / "leggen" / "auth.json"
auth_file.parent.mkdir(parents=True, exist_ok=True) auth_file.parent.mkdir(parents=True, exist_ok=True)
with open(auth_file, "w") as f: with open(auth_file, "w") as f:
json.dump(auth_data, f) json.dump(auth_data, f)
@@ -95,22 +94,21 @@ class GoCardlessService:
response = await client.get( response = await client.get(
f"{self.base_url}/institutions/", f"{self.base_url}/institutions/",
headers=headers, headers=headers,
params={"country": country} params={"country": country},
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
async def create_requisition(self, institution_id: str, redirect_url: str) -> Dict[str, Any]: async def create_requisition(
self, institution_id: str, redirect_url: str
) -> Dict[str, Any]:
"""Create a bank connection requisition""" """Create a bank connection requisition"""
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/requisitions/", f"{self.base_url}/requisitions/",
headers=headers, headers=headers,
json={ json={"institution_id": institution_id, "redirect": redirect_url},
"institution_id": institution_id,
"redirect": redirect_url
}
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -120,8 +118,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/requisitions/", f"{self.base_url}/requisitions/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -131,8 +128,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/accounts/{account_id}/", f"{self.base_url}/accounts/{account_id}/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -142,8 +138,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/accounts/{account_id}/balances/", f"{self.base_url}/accounts/{account_id}/balances/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -153,8 +148,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/accounts/{account_id}/transactions/", f"{self.base_url}/accounts/{account_id}/transactions/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

View File

@@ -10,7 +10,9 @@ class NotificationService:
self.notifications_config = config.notifications_config self.notifications_config = config.notifications_config
self.filters_config = config.filters_config self.filters_config = config.filters_config
async def send_transaction_notifications(self, transactions: List[Dict[str, Any]]) -> None: async def send_transaction_notifications(
self, transactions: List[Dict[str, Any]]
) -> None:
"""Send notifications for new transactions that match filters""" """Send notifications for new transactions that match filters"""
if not self.filters_config: if not self.filters_config:
logger.info("No notification filters configured, skipping notifications") logger.info("No notification filters configured, skipping notifications")
@@ -18,7 +20,7 @@ class NotificationService:
# Filter transactions that match notification criteria # Filter transactions that match notification criteria
matching_transactions = self._filter_transactions(transactions) matching_transactions = self._filter_transactions(transactions)
if not matching_transactions: if not matching_transactions:
logger.info("No transactions matched notification filters") logger.info("No transactions matched notification filters")
return return
@@ -26,7 +28,7 @@ class NotificationService:
# Send to enabled notification services # Send to enabled notification services
if self._is_discord_enabled(): if self._is_discord_enabled():
await self._send_discord_notifications(matching_transactions) await self._send_discord_notifications(matching_transactions)
if self._is_telegram_enabled(): if self._is_telegram_enabled():
await self._send_telegram_notifications(matching_transactions) await self._send_telegram_notifications(matching_transactions)
@@ -40,7 +42,9 @@ class NotificationService:
await self._send_telegram_test(message) await self._send_telegram_test(message)
return True return True
else: else:
logger.error(f"Notification service '{service}' not enabled or not found") logger.error(
f"Notification service '{service}' not enabled or not found"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"Failed to send test notification to {service}: {e}") logger.error(f"Failed to send test notification to {service}: {e}")
@@ -50,54 +54,66 @@ class NotificationService:
"""Send notification about account expiry""" """Send notification about account expiry"""
if self._is_discord_enabled(): if self._is_discord_enabled():
await self._send_discord_expiry(notification_data) await self._send_discord_expiry(notification_data)
if self._is_telegram_enabled(): if self._is_telegram_enabled():
await self._send_telegram_expiry(notification_data) await self._send_telegram_expiry(notification_data)
def _filter_transactions(self, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _filter_transactions(
self, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Filter transactions based on notification criteria""" """Filter transactions based on notification criteria"""
matching = [] matching = []
filters_case_insensitive = self.filters_config.get("case-insensitive", {}) filters_case_insensitive = self.filters_config.get("case-insensitive", {})
for transaction in transactions: for transaction in transactions:
description = transaction.get("description", "").lower() description = transaction.get("description", "").lower()
# Check case-insensitive filters # Check case-insensitive filters
for filter_name, filter_value in filters_case_insensitive.items(): for filter_name, filter_value in filters_case_insensitive.items():
if filter_value.lower() in description: if filter_value.lower() in description:
matching.append({ matching.append(
"name": transaction["description"], {
"value": transaction["transactionValue"], "name": transaction["description"],
"currency": transaction["transactionCurrency"], "value": transaction["transactionValue"],
"date": transaction["transactionDate"], "currency": transaction["transactionCurrency"],
}) "date": transaction["transactionDate"],
}
)
break break
return matching return matching
def _is_discord_enabled(self) -> bool: def _is_discord_enabled(self) -> bool:
"""Check if Discord notifications are enabled""" """Check if Discord notifications are enabled"""
discord_config = self.notifications_config.get("discord", {}) discord_config = self.notifications_config.get("discord", {})
return bool(discord_config.get("webhook") and discord_config.get("enabled", True)) return bool(
discord_config.get("webhook") and discord_config.get("enabled", True)
)
def _is_telegram_enabled(self) -> bool: def _is_telegram_enabled(self) -> bool:
"""Check if Telegram notifications are enabled""" """Check if Telegram notifications are enabled"""
telegram_config = self.notifications_config.get("telegram", {}) telegram_config = self.notifications_config.get("telegram", {})
return bool( return bool(
telegram_config.get("token") and telegram_config.get("token")
telegram_config.get("chat_id") and and telegram_config.get("chat_id")
telegram_config.get("enabled", True) and telegram_config.get("enabled", True)
) )
async def _send_discord_notifications(self, transactions: List[Dict[str, Any]]) -> None: async def _send_discord_notifications(
self, transactions: List[Dict[str, Any]]
) -> None:
"""Send Discord notifications - placeholder implementation""" """Send Discord notifications - placeholder implementation"""
# Would import and use leggen.notifications.discord # Would import and use leggen.notifications.discord
logger.info(f"Sending {len(transactions)} transaction notifications to Discord") logger.info(f"Sending {len(transactions)} transaction notifications to Discord")
async def _send_telegram_notifications(self, transactions: List[Dict[str, Any]]) -> None: async def _send_telegram_notifications(
self, transactions: List[Dict[str, Any]]
) -> None:
"""Send Telegram notifications - placeholder implementation""" """Send Telegram notifications - placeholder implementation"""
# Would import and use leggen.notifications.telegram # Would import and use leggen.notifications.telegram
logger.info(f"Sending {len(transactions)} transaction notifications to Telegram") logger.info(
f"Sending {len(transactions)} transaction notifications to Telegram"
)
async def _send_discord_test(self, message: str) -> None: async def _send_discord_test(self, message: str) -> None:
"""Send Discord test notification""" """Send Discord test notification"""
@@ -113,4 +129,4 @@ class NotificationService:
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None: async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
"""Send Telegram expiry notification""" """Send Telegram expiry notification"""
logger.info(f"Sending Telegram expiry notification: {notification_data}") logger.info(f"Sending Telegram expiry notification: {notification_data}")

View File

@@ -30,7 +30,7 @@ class SyncService:
start_time = datetime.now() start_time = datetime.now()
self._sync_status.is_running = True self._sync_status.is_running = True
self._sync_status.errors = [] self._sync_status.errors = []
accounts_processed = 0 accounts_processed = 0
transactions_added = 0 transactions_added = 0
transactions_updated = 0 transactions_updated = 0
@@ -39,22 +39,24 @@ class SyncService:
try: try:
logger.info("Starting sync of all accounts") logger.info("Starting sync of all accounts")
# Get all requisitions and accounts # Get all requisitions and accounts
requisitions = await self.gocardless.get_requisitions() requisitions = await self.gocardless.get_requisitions()
all_accounts = set() all_accounts = set()
for req in requisitions.get("results", []): for req in requisitions.get("results", []):
all_accounts.update(req.get("accounts", [])) all_accounts.update(req.get("accounts", []))
self._sync_status.total_accounts = len(all_accounts) self._sync_status.total_accounts = len(all_accounts)
# Process each account # Process each account
for account_id in all_accounts: for account_id in all_accounts:
try: try:
# Get account details # Get account details
account_details = await self.gocardless.get_account_details(account_id) account_details = await self.gocardless.get_account_details(
account_id
)
# Get and save balances # Get and save balances
balances = await self.gocardless.get_account_balances(account_id) balances = await self.gocardless.get_account_balances(account_id)
if balances: if balances:
@@ -62,7 +64,9 @@ class SyncService:
balances_updated += len(balances.get("balances", [])) balances_updated += len(balances.get("balances", []))
# Get and save transactions # Get and save transactions
transactions = await self.gocardless.get_account_transactions(account_id) transactions = await self.gocardless.get_account_transactions(
account_id
)
if transactions: if transactions:
processed_transactions = self.database.process_transactions( processed_transactions = self.database.process_transactions(
account_id, account_details, transactions account_id, account_details, transactions
@@ -71,16 +75,18 @@ class SyncService:
account_id, processed_transactions account_id, processed_transactions
) )
transactions_added += len(new_transactions) transactions_added += len(new_transactions)
# Send notifications for new transactions # Send notifications for new transactions
if new_transactions: if new_transactions:
await self.notifications.send_transaction_notifications(new_transactions) await self.notifications.send_transaction_notifications(
new_transactions
)
accounts_processed += 1 accounts_processed += 1
self._sync_status.accounts_synced = accounts_processed self._sync_status.accounts_synced = accounts_processed
logger.info(f"Synced account {account_id} successfully") logger.info(f"Synced account {account_id} successfully")
except Exception as e: except Exception as e:
error_msg = f"Failed to sync account {account_id}: {str(e)}" error_msg = f"Failed to sync account {account_id}: {str(e)}"
errors.append(error_msg) errors.append(error_msg)
@@ -88,9 +94,9 @@ class SyncService:
end_time = datetime.now() end_time = datetime.now()
duration = (end_time - start_time).total_seconds() duration = (end_time - start_time).total_seconds()
self._sync_status.last_sync = end_time self._sync_status.last_sync = end_time
result = SyncResult( result = SyncResult(
success=len(errors) == 0, success=len(errors) == 0,
accounts_processed=accounts_processed, accounts_processed=accounts_processed,
@@ -100,12 +106,14 @@ class SyncService:
duration_seconds=duration, duration_seconds=duration,
errors=errors, errors=errors,
started_at=start_time, started_at=start_time,
completed_at=end_time completed_at=end_time,
)
logger.info(
f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions"
) )
logger.info(f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions")
return result return result
except Exception as e: except Exception as e:
error_msg = f"Sync failed: {str(e)}" error_msg = f"Sync failed: {str(e)}"
errors.append(error_msg) errors.append(error_msg)
@@ -114,7 +122,9 @@ class SyncService:
finally: finally:
self._sync_status.is_running = False self._sync_status.is_running = False
async def sync_specific_accounts(self, account_ids: List[str], force: bool = False) -> SyncResult: async def sync_specific_accounts(
self, account_ids: List[str], force: bool = False
) -> SyncResult:
"""Sync specific accounts""" """Sync specific accounts"""
if self._sync_status.is_running and not force: if self._sync_status.is_running and not force:
raise Exception("Sync is already running") raise Exception("Sync is already running")
@@ -123,12 +133,12 @@ class SyncService:
# For brevity, implementing a simplified version # For brevity, implementing a simplified version
start_time = datetime.now() start_time = datetime.now()
self._sync_status.is_running = True self._sync_status.is_running = True
try: try:
# Process only specified accounts # Process only specified accounts
# Implementation would be similar to sync_all_accounts # Implementation would be similar to sync_all_accounts
# but filtered to only the specified account_ids # but filtered to only the specified account_ids
end_time = datetime.now() end_time = datetime.now()
return SyncResult( return SyncResult(
success=True, success=True,
@@ -139,7 +149,7 @@ class SyncService:
duration_seconds=(end_time - start_time).total_seconds(), duration_seconds=(end_time - start_time).total_seconds(),
errors=[], errors=[],
started_at=start_time, started_at=start_time,
completed_at=end_time completed_at=end_time,
) )
finally: finally:
self._sync_status.is_running = False self._sync_status.is_running = False

View File

@@ -7,4 +7,4 @@ REQUISITION_STATUS = {
"GA": "GRANTING_ACCESS", "GA": "GRANTING_ACCESS",
"LN": "LINKED", "LN": "LINKED",
"EX": "EXPIRED", "EX": "EXPIRED",
} }

View File

View File

@@ -1,4 +1,5 @@
"""Pytest configuration and shared fixtures.""" """Pytest configuration and shared fixtures."""
import pytest import pytest
import tempfile import tempfile
import json import json
@@ -19,34 +20,27 @@ def temp_config_dir():
yield config_dir yield config_dir
@pytest.fixture @pytest.fixture
def mock_config(temp_config_dir): def mock_config(temp_config_dir):
"""Mock configuration for testing.""" """Mock configuration for testing."""
config_data = { config_data = {
"gocardless": { "gocardless": {
"key": "test-key", "key": "test-key",
"secret": "test-secret", "secret": "test-secret",
"url": "https://bankaccountdata.gocardless.com/api/v2" "url": "https://bankaccountdata.gocardless.com/api/v2",
}, },
"database": { "database": {"sqlite": True},
"sqlite": True "scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}},
},
"scheduler": {
"sync": {
"enabled": True,
"hour": 3,
"minute": 0
}
}
} }
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(config_data, f) tomli_w.dump(config_data, f)
# Mock the config path # Mock the config path
with patch.object(Config, 'load_config') as mock_load: with patch.object(Config, "load_config") as mock_load:
mock_load.return_value = config_data mock_load.return_value = config_data
config = Config() config = Config()
config._config = config_data config._config = config_data
@@ -57,15 +51,12 @@ def mock_config(temp_config_dir):
@pytest.fixture @pytest.fixture
def mock_auth_token(temp_config_dir): def mock_auth_token(temp_config_dir):
"""Mock GoCardless authentication token.""" """Mock GoCardless authentication token."""
auth_data = { auth_data = {"access": "mock-access-token", "refresh": "mock-refresh-token"}
"access": "mock-access-token",
"refresh": "mock-refresh-token" auth_file = temp_config_dir / "auth.json"
}
auth_file = temp_config_dir / "auth.json"
with open(auth_file, "w") as f: with open(auth_file, "w") as f:
json.dump(auth_data, f) json.dump(auth_data, f)
return auth_data return auth_data
@@ -88,17 +79,17 @@ def sample_bank_data():
{ {
"id": "REVOLUT_REVOLT21", "id": "REVOLUT_REVOLT21",
"name": "Revolut", "name": "Revolut",
"bic": "REVOLT21", "bic": "REVOLT21",
"transaction_total_days": 90, "transaction_total_days": 90,
"countries": ["GB", "LT"] "countries": ["GB", "LT"],
}, },
{ {
"id": "BANCOBPI_BBPIPTPL", "id": "BANCOBPI_BBPIPTPL",
"name": "Banco BPI", "name": "Banco BPI",
"bic": "BBPIPTPL", "bic": "BBPIPTPL",
"transaction_total_days": 90, "transaction_total_days": 90,
"countries": ["PT"] "countries": ["PT"],
} },
] ]
@@ -107,31 +98,28 @@ def sample_account_data():
"""Sample account data for testing.""" """Sample account data for testing."""
return { return {
"id": "test-account-123", "id": "test-account-123",
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"status": "READY", "status": "READY",
"iban": "LT313250081177977789", "iban": "LT313250081177977789",
"created": "2024-02-13T23:56:00Z", "created": "2024-02-13T23:56:00Z",
"last_accessed": "2025-09-01T09:30:00Z" "last_accessed": "2025-09-01T09:30:00Z",
} }
@pytest.fixture @pytest.fixture
def sample_transaction_data(): def sample_transaction_data():
"""Sample transaction data for testing.""" """Sample transaction data for testing."""
return { return {
"transactions": { "transactions": {
"booked": [ "booked": [
{ {
"internalTransactionId": "txn-123", "internalTransactionId": "txn-123",
"bookingDate": "2025-09-01", "bookingDate": "2025-09-01",
"valueDate": "2025-09-01", "valueDate": "2025-09-01",
"transactionAmount": { "transactionAmount": {"amount": "-10.50", "currency": "EUR"},
"amount": "-10.50", "remittanceInformationUnstructured": "Coffee Shop Payment",
"currency": "EUR"
},
"remittanceInformationUnstructured": "Coffee Shop Payment"
} }
], ],
"pending": [] "pending": [],
} }
} }

View File

@@ -1,4 +1,5 @@
"""Tests for accounts API endpoints.""" """Tests for accounts API endpoints."""
import pytest import pytest
import respx import respx
import httpx import httpx
@@ -8,48 +9,47 @@ from unittest.mock import patch
@pytest.mark.api @pytest.mark.api
class TestAccountsAPI: class TestAccountsAPI:
"""Test account-related API endpoints.""" """Test account-related API endpoints."""
@respx.mock @respx.mock
def test_get_all_accounts_success(self, api_client, mock_config, mock_auth_token, sample_account_data): def test_get_all_accounts_success(
self, api_client, mock_config, mock_auth_token, sample_account_data
):
"""Test successful retrieval of all accounts.""" """Test successful retrieval of all accounts."""
requisitions_data = { requisitions_data = {
"results": [ "results": [{"id": "req-123", "accounts": ["test-account-123"]}]
{
"id": "req-123",
"accounts": ["test-account-123"]
}
]
} }
balances_data = { balances_data = {
"balances": [ "balances": [
{ {
"balanceAmount": {"amount": "100.50", "currency": "EUR"}, "balanceAmount": {"amount": "100.50", "currency": "EUR"},
"balanceType": "interimAvailable", "balanceType": "interimAvailable",
"lastChangeDateTime": "2025-09-01T09:30:00Z" "lastChangeDateTime": "2025-09-01T09:30:00Z",
} }
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
return_value=httpx.Response(200, json=requisitions_data) return_value=httpx.Response(200, json=requisitions_data)
) )
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( respx.get(
return_value=httpx.Response(200, json=balances_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
) ).mock(return_value=httpx.Response(200, json=balances_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts") response = api_client.get("/api/v1/accounts")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -61,33 +61,37 @@ class TestAccountsAPI:
assert account["balances"][0]["amount"] == 100.50 assert account["balances"][0]["amount"] == 100.50
@respx.mock @respx.mock
def test_get_account_details_success(self, api_client, mock_config, mock_auth_token, sample_account_data): def test_get_account_details_success(
self, api_client, mock_config, mock_auth_token, sample_account_data
):
"""Test successful retrieval of specific account details.""" """Test successful retrieval of specific account details."""
balances_data = { balances_data = {
"balances": [ "balances": [
{ {
"balanceAmount": {"amount": "250.75", "currency": "EUR"}, "balanceAmount": {"amount": "250.75", "currency": "EUR"},
"balanceType": "interimAvailable" "balanceType": "interimAvailable",
} }
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( respx.get(
return_value=httpx.Response(200, json=balances_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
) ).mock(return_value=httpx.Response(200, json=balances_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123") response = api_client.get("/api/v1/accounts/test-account-123")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -97,35 +101,39 @@ class TestAccountsAPI:
assert len(account["balances"]) == 1 assert len(account["balances"]) == 1
@respx.mock @respx.mock
def test_get_account_balances_success(self, api_client, mock_config, mock_auth_token): def test_get_account_balances_success(
self, api_client, mock_config, mock_auth_token
):
"""Test successful retrieval of account balances.""" """Test successful retrieval of account balances."""
balances_data = { balances_data = {
"balances": [ "balances": [
{ {
"balanceAmount": {"amount": "1000.00", "currency": "EUR"}, "balanceAmount": {"amount": "1000.00", "currency": "EUR"},
"balanceType": "interimAvailable", "balanceType": "interimAvailable",
"lastChangeDateTime": "2025-09-01T10:00:00Z" "lastChangeDateTime": "2025-09-01T10:00:00Z",
}, },
{ {
"balanceAmount": {"amount": "950.00", "currency": "EUR"}, "balanceAmount": {"amount": "950.00", "currency": "EUR"},
"balanceType": "expected" "balanceType": "expected",
} },
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API # Mock GoCardless API
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( respx.get(
return_value=httpx.Response(200, json=balances_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
) ).mock(return_value=httpx.Response(200, json=balances_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/balances") response = api_client.get("/api/v1/accounts/test-account-123/balances")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -135,29 +143,40 @@ class TestAccountsAPI:
assert data["data"][0]["balance_type"] == "interimAvailable" assert data["data"][0]["balance_type"] == "interimAvailable"
@respx.mock @respx.mock
def test_get_account_transactions_success(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data): def test_get_account_transactions_success(
self,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
):
"""Test successful retrieval of account transactions.""" """Test successful retrieval of account transactions."""
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock( respx.get(
return_value=httpx.Response(200, json=sample_transaction_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
) ).mock(return_value=httpx.Response(200, json=sample_transaction_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=true") response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert len(data["data"]) == 1 assert len(data["data"]) == 1
transaction = data["data"][0] transaction = data["data"][0]
assert transaction["internal_transaction_id"] == "txn-123" assert transaction["internal_transaction_id"] == "txn-123"
assert transaction["amount"] == -10.50 assert transaction["amount"] == -10.50
@@ -165,29 +184,40 @@ class TestAccountsAPI:
assert transaction["description"] == "Coffee Shop Payment" assert transaction["description"] == "Coffee Shop Payment"
@respx.mock @respx.mock
def test_get_account_transactions_full_details(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data): def test_get_account_transactions_full_details(
self,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
):
"""Test retrieval of full transaction details.""" """Test retrieval of full transaction details."""
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock( respx.get(
return_value=httpx.Response(200, json=sample_transaction_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
) ).mock(return_value=httpx.Response(200, json=sample_transaction_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=false") response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert len(data["data"]) == 1 assert len(data["data"]) == 1
transaction = data["data"][0] transaction = data["data"][0]
assert transaction["internal_transaction_id"] == "txn-123" assert transaction["internal_transaction_id"] == "txn-123"
assert transaction["institution_id"] == "REVOLUT_REVOLT21" assert transaction["institution_id"] == "REVOLUT_REVOLT21"
@@ -200,14 +230,18 @@ class TestAccountsAPI:
with respx.mock: with respx.mock:
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/").mock( respx.get(
"https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/"
).mock(
return_value=httpx.Response(404, json={"detail": "Account not found"}) return_value=httpx.Response(404, json={"detail": "Account not found"})
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/nonexistent") response = api_client.get("/api/v1/accounts/nonexistent")
assert response.status_code == 404 assert response.status_code == 404

View File

@@ -1,4 +1,5 @@
"""Tests for banks API endpoints.""" """Tests for banks API endpoints."""
import pytest import pytest
import respx import respx
import httpx import httpx
@@ -10,23 +11,27 @@ from leggend.services.gocardless_service import GoCardlessService
@pytest.mark.api @pytest.mark.api
class TestBanksAPI: class TestBanksAPI:
"""Test bank-related API endpoints.""" """Test bank-related API endpoints."""
@respx.mock @respx.mock
def test_get_institutions_success(self, api_client, mock_config, mock_auth_token, sample_bank_data): def test_get_institutions_success(
self, api_client, mock_config, mock_auth_token, sample_bank_data
):
"""Test successful retrieval of bank institutions.""" """Test successful retrieval of bank institutions."""
# Mock GoCardless token creation/refresh # Mock GoCardless token creation/refresh
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless institutions API # Mock GoCardless institutions API
respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock( respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock(
return_value=httpx.Response(200, json=sample_bank_data) return_value=httpx.Response(200, json=sample_bank_data)
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/institutions?country=PT") response = api_client.get("/api/v1/banks/institutions?country=PT")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -39,17 +44,19 @@ class TestBanksAPI:
"""Test institutions endpoint with invalid country code.""" """Test institutions endpoint with invalid country code."""
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock empty institutions response for invalid country # Mock empty institutions response for invalid country
respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock( respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock(
return_value=httpx.Response(200, json=[]) return_value=httpx.Response(200, json=[])
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/institutions?country=XX") response = api_client.get("/api/v1/banks/institutions?country=XX")
# Should still work but return empty or filtered results # Should still work but return empty or filtered results
assert response.status_code in [200, 404] assert response.status_code in [200, 404]
@@ -61,27 +68,29 @@ class TestBanksAPI:
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"status": "CR", "status": "CR",
"created": "2025-09-02T00:00:00Z", "created": "2025-09-02T00:00:00Z",
"link": "https://example.com/auth" "link": "https://example.com/auth",
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless requisitions API # Mock GoCardless requisitions API
respx.post("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
return_value=httpx.Response(200, json=requisition_data) return_value=httpx.Response(200, json=requisition_data)
) )
request_data = { request_data = {
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"redirect_url": "http://localhost:8000/" "redirect_url": "http://localhost:8000/",
} }
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.post("/api/v1/banks/connect", json=request_data) response = api_client.post("/api/v1/banks/connect", json=request_data)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -95,27 +104,29 @@ class TestBanksAPI:
"results": [ "results": [
{ {
"id": "req-123", "id": "req-123",
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"status": "LN", "status": "LN",
"created": "2025-09-02T00:00:00Z", "created": "2025-09-02T00:00:00Z",
"accounts": ["acc-123"] "accounts": ["acc-123"],
} }
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless requisitions API # Mock GoCardless requisitions API
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
return_value=httpx.Response(200, json=requisitions_data) return_value=httpx.Response(200, json=requisitions_data)
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/status") response = api_client.get("/api/v1/banks/status")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -126,12 +137,12 @@ class TestBanksAPI:
def test_get_supported_countries(self, api_client): def test_get_supported_countries(self, api_client):
"""Test supported countries endpoint.""" """Test supported countries endpoint."""
response = api_client.get("/api/v1/banks/countries") response = api_client.get("/api/v1/banks/countries")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert len(data["data"]) > 0 assert len(data["data"]) > 0
# Check some expected countries # Check some expected countries
country_codes = [country["code"] for country in data["data"]] country_codes = [country["code"] for country in data["data"]]
assert "PT" in country_codes assert "PT" in country_codes
@@ -145,10 +156,10 @@ class TestBanksAPI:
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(401, json={"detail": "Invalid credentials"}) return_value=httpx.Response(401, json={"detail": "Invalid credentials"})
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/institutions") response = api_client.get("/api/v1/banks/institutions")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to get institutions" in data["detail"] assert "Failed to get institutions" in data["detail"]

View File

@@ -1,4 +1,5 @@
"""Tests for CLI API client.""" """Tests for CLI API client."""
import pytest import pytest
import requests_mock import requests_mock
from unittest.mock import patch from unittest.mock import patch
@@ -13,36 +14,36 @@ class TestLeggendAPIClient:
def test_health_check_success(self): def test_health_check_success(self):
"""Test successful health check.""" """Test successful health check."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/health", json={"status": "healthy"}) m.get("http://localhost:8000/health", json={"status": "healthy"})
result = client.health_check() result = client.health_check()
assert result is True assert result is True
def test_health_check_failure(self): def test_health_check_failure(self):
"""Test health check failure.""" """Test health check failure."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/health", status_code=500) m.get("http://localhost:8000/health", status_code=500)
result = client.health_check() result = client.health_check()
assert result is False assert result is False
def test_get_institutions_success(self, sample_bank_data): def test_get_institutions_success(self, sample_bank_data):
"""Test getting institutions via API client.""" """Test getting institutions via API client."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
api_response = { api_response = {
"success": True, "success": True,
"data": sample_bank_data, "data": sample_bank_data,
"message": "Found 2 institutions for PT" "message": "Found 2 institutions for PT",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/api/v1/banks/institutions", json=api_response) m.get("http://localhost:8000/api/v1/banks/institutions", json=api_response)
result = client.get_institutions("PT") result = client.get_institutions("PT")
assert len(result) == 2 assert len(result) == 2
assert result[0]["id"] == "REVOLUT_REVOLT21" assert result[0]["id"] == "REVOLUT_REVOLT21"
@@ -50,16 +51,16 @@ class TestLeggendAPIClient:
def test_get_accounts_success(self, sample_account_data): def test_get_accounts_success(self, sample_account_data):
"""Test getting accounts via API client.""" """Test getting accounts via API client."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
api_response = { api_response = {
"success": True, "success": True,
"data": [sample_account_data], "data": [sample_account_data],
"message": "Retrieved 1 accounts" "message": "Retrieved 1 accounts",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/api/v1/accounts", json=api_response) m.get("http://localhost:8000/api/v1/accounts", json=api_response)
result = client.get_accounts() result = client.get_accounts()
assert len(result) == 1 assert len(result) == 1
assert result[0]["id"] == "test-account-123" assert result[0]["id"] == "test-account-123"
@@ -67,34 +68,37 @@ class TestLeggendAPIClient:
def test_trigger_sync_success(self): def test_trigger_sync_success(self):
"""Test triggering sync via API client.""" """Test triggering sync via API client."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
api_response = { api_response = {
"success": True, "success": True,
"data": {"sync_started": True, "force": False}, "data": {"sync_started": True, "force": False},
"message": "Started sync for all accounts" "message": "Started sync for all accounts",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.post("http://localhost:8000/api/v1/sync", json=api_response) m.post("http://localhost:8000/api/v1/sync", json=api_response)
result = client.trigger_sync() result = client.trigger_sync()
assert result["sync_started"] is True assert result["sync_started"] is True
def test_connection_error_handling(self): def test_connection_error_handling(self):
"""Test handling of connection errors.""" """Test handling of connection errors."""
client = LeggendAPIClient("http://localhost:9999") # Non-existent service client = LeggendAPIClient("http://localhost:9999") # Non-existent service
with pytest.raises(Exception): with pytest.raises(Exception):
client.get_accounts() client.get_accounts()
def test_http_error_handling(self): def test_http_error_handling(self):
"""Test handling of HTTP errors.""" """Test handling of HTTP errors."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/api/v1/accounts", status_code=500, m.get(
json={"detail": "Internal server error"}) "http://localhost:8000/api/v1/accounts",
status_code=500,
json={"detail": "Internal server error"},
)
with pytest.raises(Exception): with pytest.raises(Exception):
client.get_accounts() client.get_accounts()
@@ -102,28 +106,28 @@ class TestLeggendAPIClient:
"""Test using custom API URL.""" """Test using custom API URL."""
custom_url = "http://custom-host:9000" custom_url = "http://custom-host:9000"
client = LeggendAPIClient(custom_url) client = LeggendAPIClient(custom_url)
assert client.base_url == custom_url assert client.base_url == custom_url
def test_environment_variable_url(self): def test_environment_variable_url(self):
"""Test using environment variable for API URL.""" """Test using environment variable for API URL."""
with patch.dict('os.environ', {'LEGGEND_API_URL': 'http://env-host:7000'}): with patch.dict("os.environ", {"LEGGEND_API_URL": "http://env-host:7000"}):
client = LeggendAPIClient() client = LeggendAPIClient()
assert client.base_url == "http://env-host:7000" assert client.base_url == "http://env-host:7000"
def test_sync_with_options(self): def test_sync_with_options(self):
"""Test sync with various options.""" """Test sync with various options."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
api_response = { api_response = {
"success": True, "success": True,
"data": {"sync_started": True, "force": True}, "data": {"sync_started": True, "force": True},
"message": "Started sync for 2 specific accounts" "message": "Started sync for 2 specific accounts",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.post("http://localhost:8000/api/v1/sync", json=api_response) m.post("http://localhost:8000/api/v1/sync", json=api_response)
result = client.trigger_sync(account_ids=["acc1", "acc2"], force=True) result = client.trigger_sync(account_ids=["acc1", "acc2"], force=True)
assert result["sync_started"] is True assert result["sync_started"] is True
assert result["force"] is True assert result["force"] is True
@@ -131,20 +135,20 @@ class TestLeggendAPIClient:
def test_get_scheduler_config(self): def test_get_scheduler_config(self):
"""Test getting scheduler configuration.""" """Test getting scheduler configuration."""
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
api_response = { api_response = {
"success": True, "success": True,
"data": { "data": {
"enabled": True, "enabled": True,
"hour": 3, "hour": 3,
"minute": 0, "minute": 0,
"next_scheduled_sync": "2025-09-03T03:00:00Z" "next_scheduled_sync": "2025-09-03T03:00:00Z",
} },
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/api/v1/sync/scheduler", json=api_response) m.get("http://localhost:8000/api/v1/sync/scheduler", json=api_response)
result = client.get_scheduler_config() result = client.get_scheduler_config()
assert result["enabled"] is True assert result["enabled"] is True
assert result["hour"] == 3 assert result["hour"] == 3

View File

@@ -1,4 +1,5 @@
"""Tests for configuration management.""" """Tests for configuration management."""
import pytest import pytest
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@@ -23,25 +24,24 @@ class TestConfig:
"gocardless": { "gocardless": {
"key": "test-key", "key": "test-key",
"secret": "test-secret", "secret": "test-secret",
"url": "https://test.example.com" "url": "https://test.example.com",
}, },
"database": { "database": {"sqlite": True},
"sqlite": True
}
} }
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(config_data, f) tomli_w.dump(config_data, f)
config = Config() config = Config()
# Reset singleton state for testing # Reset singleton state for testing
config._config = None config._config = None
config._config_path = None config._config_path = None
result = config.load_config(str(config_file)) result = config.load_config(str(config_file))
assert result == config_data assert result == config_data
assert config.gocardless_config["key"] == "test-key" assert config.gocardless_config["key"] == "test-key"
assert config.database_config["sqlite"] is True assert config.database_config["sqlite"] is True
@@ -50,87 +50,84 @@ class TestConfig:
"""Test handling of missing configuration file.""" """Test handling of missing configuration file."""
config = Config() config = Config()
config._config = None # Reset for test config._config = None # Reset for test
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
config.load_config("/nonexistent/config.toml") config.load_config("/nonexistent/config.toml")
def test_save_config_success(self, temp_config_dir): def test_save_config_success(self, temp_config_dir):
"""Test successful configuration saving.""" """Test successful configuration saving."""
config_data = { config_data = {"gocardless": {"key": "new-key", "secret": "new-secret"}}
"gocardless": {
"key": "new-key",
"secret": "new-secret"
}
}
config_file = temp_config_dir / "new_config.toml" config_file = temp_config_dir / "new_config.toml"
config = Config() config = Config()
config._config = None config._config = None
config.save_config(config_data, str(config_file)) config.save_config(config_data, str(config_file))
# Verify file was created and contains correct data # Verify file was created and contains correct data
assert config_file.exists() assert config_file.exists()
import tomllib import tomllib
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
saved_data = tomllib.load(f) saved_data = tomllib.load(f)
assert saved_data == config_data assert saved_data == config_data
def test_update_config_success(self, temp_config_dir): def test_update_config_success(self, temp_config_dir):
"""Test updating configuration values.""" """Test updating configuration values."""
initial_config = { initial_config = {
"gocardless": {"key": "old-key"}, "gocardless": {"key": "old-key"},
"database": {"sqlite": True} "database": {"sqlite": True},
} }
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(initial_config, f) tomli_w.dump(initial_config, f)
config = Config() config = Config()
config._config = None config._config = None
config.load_config(str(config_file)) config.load_config(str(config_file))
config.update_config("gocardless", "key", "new-key") config.update_config("gocardless", "key", "new-key")
assert config.gocardless_config["key"] == "new-key" assert config.gocardless_config["key"] == "new-key"
# Verify it was saved to file # Verify it was saved to file
import tomllib import tomllib
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
saved_data = tomllib.load(f) saved_data = tomllib.load(f)
assert saved_data["gocardless"]["key"] == "new-key" assert saved_data["gocardless"]["key"] == "new-key"
def test_update_section_success(self, temp_config_dir): def test_update_section_success(self, temp_config_dir):
"""Test updating entire configuration section.""" """Test updating entire configuration section."""
initial_config = { initial_config = {"database": {"sqlite": True}}
"database": {"sqlite": True}
}
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(initial_config, f) tomli_w.dump(initial_config, f)
config = Config() config = Config()
config._config = None config._config = None
config.load_config(str(config_file)) config.load_config(str(config_file))
new_db_config = {"sqlite": False, "path": "./custom.db"} new_db_config = {"sqlite": False, "path": "./custom.db"}
config.update_section("database", new_db_config) config.update_section("database", new_db_config)
assert config.database_config == new_db_config assert config.database_config == new_db_config
def test_scheduler_config_defaults(self): def test_scheduler_config_defaults(self):
"""Test scheduler configuration with defaults.""" """Test scheduler configuration with defaults."""
config = Config() config = Config()
config._config = {} # Empty config config._config = {} # Empty config
scheduler_config = config.scheduler_config scheduler_config = config.scheduler_config
assert scheduler_config["sync"]["enabled"] is True assert scheduler_config["sync"]["enabled"] is True
assert scheduler_config["sync"]["hour"] == 3 assert scheduler_config["sync"]["hour"] == 3
assert scheduler_config["sync"]["minute"] == 0 assert scheduler_config["sync"]["minute"] == 0
@@ -144,16 +141,16 @@ class TestConfig:
"enabled": False, "enabled": False,
"hour": 6, "hour": 6,
"minute": 30, "minute": 30,
"cron": "0 6 * * 1-5" "cron": "0 6 * * 1-5",
} }
} }
} }
config = Config() config = Config()
config._config = custom_config config._config = custom_config
scheduler_config = config.scheduler_config scheduler_config = config.scheduler_config
assert scheduler_config["sync"]["enabled"] is False assert scheduler_config["sync"]["enabled"] is False
assert scheduler_config["sync"]["hour"] == 6 assert scheduler_config["sync"]["hour"] == 6
assert scheduler_config["sync"]["minute"] == 30 assert scheduler_config["sync"]["minute"] == 30
@@ -161,26 +158,28 @@ class TestConfig:
def test_environment_variable_config_path(self): def test_environment_variable_config_path(self):
"""Test using environment variable for config path.""" """Test using environment variable for config path."""
with patch.dict('os.environ', {'LEGGEN_CONFIG_FILE': '/custom/path/config.toml'}): with patch.dict(
"os.environ", {"LEGGEN_CONFIG_FILE": "/custom/path/config.toml"}
):
config = Config() config = Config()
config._config = None config._config = None
with patch('builtins.open', side_effect=FileNotFoundError): with patch("builtins.open", side_effect=FileNotFoundError):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
config.load_config() config.load_config()
def test_notifications_config(self): def test_notifications_config(self):
"""Test notifications configuration access.""" """Test notifications configuration access."""
custom_config = { custom_config = {
"notifications": { "notifications": {
"discord": {"webhook": "https://discord.webhook", "enabled": True}, "discord": {"webhook": "https://discord.webhook", "enabled": True},
"telegram": {"token": "bot-token", "chat_id": 123} "telegram": {"token": "bot-token", "chat_id": 123},
} }
} }
config = Config() config = Config()
config._config = custom_config config._config = custom_config
notifications = config.notifications_config notifications = config.notifications_config
assert notifications["discord"]["webhook"] == "https://discord.webhook" assert notifications["discord"]["webhook"] == "https://discord.webhook"
assert notifications["telegram"]["token"] == "bot-token" assert notifications["telegram"]["token"] == "bot-token"
@@ -190,13 +189,13 @@ class TestConfig:
custom_config = { custom_config = {
"filters": { "filters": {
"case-insensitive": {"salary": "SALARY", "bills": "BILL"}, "case-insensitive": {"salary": "SALARY", "bills": "BILL"},
"amount_threshold": 100.0 "amount_threshold": 100.0,
} }
} }
config = Config() config = Config()
config._config = custom_config config._config = custom_config
filters = config.filters_config filters = config.filters_config
assert filters["case-insensitive"]["salary"] == "SALARY" assert filters["case-insensitive"]["salary"] == "SALARY"
assert filters["amount_threshold"] == 100.0 assert filters["amount_threshold"] == 100.0

View File

@@ -1,4 +1,5 @@
"""Tests for background scheduler.""" """Tests for background scheduler."""
import pytest import pytest
import asyncio import asyncio
from unittest.mock import Mock, patch, AsyncMock, MagicMock from unittest.mock import Mock, patch, AsyncMock, MagicMock
@@ -16,23 +17,19 @@ class TestBackgroundScheduler:
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Mock configuration for scheduler tests.""" """Mock configuration for scheduler tests."""
return { return {"sync": {"enabled": True, "hour": 3, "minute": 0, "cron": None}}
"sync": {
"enabled": True,
"hour": 3,
"minute": 0,
"cron": None
}
}
@pytest.fixture @pytest.fixture
def scheduler(self): def scheduler(self):
"""Create scheduler instance for testing.""" """Create scheduler instance for testing."""
with patch('leggend.background.scheduler.SyncService'), \ with (
patch('leggend.background.scheduler.config') as mock_config: patch("leggend.background.scheduler.SyncService"),
patch("leggend.background.scheduler.config") as mock_config,
mock_config.scheduler_config = {"sync": {"enabled": True, "hour": 3, "minute": 0}} ):
mock_config.scheduler_config = {
"sync": {"enabled": True, "hour": 3, "minute": 0}
}
# Create scheduler and replace its AsyncIO scheduler with a mock # Create scheduler and replace its AsyncIO scheduler with a mock
scheduler = BackgroundScheduler() scheduler = BackgroundScheduler()
mock_scheduler = MagicMock() mock_scheduler = MagicMock()
@@ -43,35 +40,34 @@ class TestBackgroundScheduler:
def test_scheduler_start_default_config(self, scheduler, mock_config): def test_scheduler_start_default_config(self, scheduler, mock_config):
"""Test starting scheduler with default configuration.""" """Test starting scheduler with default configuration."""
with patch('leggend.config.config') as mock_config_obj: with patch("leggend.config.config") as mock_config_obj:
mock_config_obj.scheduler_config = mock_config mock_config_obj.scheduler_config = mock_config
# Mock the job that gets added # Mock the job that gets added
mock_job = MagicMock() mock_job = MagicMock()
mock_job.id = "daily_sync" mock_job.id = "daily_sync"
scheduler.scheduler.get_jobs.return_value = [mock_job] scheduler.scheduler.get_jobs.return_value = [mock_job]
scheduler.start() scheduler.start()
# Verify scheduler.start() was called # Verify scheduler.start() was called
scheduler.scheduler.start.assert_called_once() scheduler.scheduler.start.assert_called_once()
# Verify add_job was called # Verify add_job was called
scheduler.scheduler.add_job.assert_called_once() scheduler.scheduler.add_job.assert_called_once()
def test_scheduler_start_disabled(self, scheduler): def test_scheduler_start_disabled(self, scheduler):
"""Test scheduler behavior when sync is disabled.""" """Test scheduler behavior when sync is disabled."""
disabled_config = { disabled_config = {"sync": {"enabled": False}}
"sync": {"enabled": False}
} with (
patch.object(scheduler, "scheduler") as mock_scheduler,
with patch.object(scheduler, 'scheduler') as mock_scheduler, \ patch("leggend.background.scheduler.config") as mock_config_obj,
patch('leggend.background.scheduler.config') as mock_config_obj: ):
mock_config_obj.scheduler_config = disabled_config mock_config_obj.scheduler_config = disabled_config
mock_scheduler.running = False mock_scheduler.running = False
scheduler.start() scheduler.start()
# Verify scheduler.start() was called # Verify scheduler.start() was called
mock_scheduler.start.assert_called_once() mock_scheduler.start.assert_called_once()
# Verify add_job was NOT called for disabled sync # Verify add_job was NOT called for disabled sync
@@ -82,39 +78,35 @@ class TestBackgroundScheduler:
cron_config = { cron_config = {
"sync": { "sync": {
"enabled": True, "enabled": True,
"cron": "0 6 * * 1-5" # 6 AM on weekdays "cron": "0 6 * * 1-5", # 6 AM on weekdays
} }
} }
with patch('leggend.config.config') as mock_config_obj: with patch("leggend.config.config") as mock_config_obj:
mock_config_obj.scheduler_config = cron_config mock_config_obj.scheduler_config = cron_config
scheduler.start() scheduler.start()
# Verify scheduler.start() and add_job were called # Verify scheduler.start() and add_job were called
scheduler.scheduler.start.assert_called_once() scheduler.scheduler.start.assert_called_once()
scheduler.scheduler.add_job.assert_called_once() scheduler.scheduler.add_job.assert_called_once()
# Verify job was added with correct ID # Verify job was added with correct ID
call_args = scheduler.scheduler.add_job.call_args call_args = scheduler.scheduler.add_job.call_args
assert call_args.kwargs['id'] == 'daily_sync' assert call_args.kwargs["id"] == "daily_sync"
def test_scheduler_start_invalid_cron(self, scheduler): def test_scheduler_start_invalid_cron(self, scheduler):
"""Test handling of invalid cron expressions.""" """Test handling of invalid cron expressions."""
invalid_cron_config = { invalid_cron_config = {"sync": {"enabled": True, "cron": "invalid cron"}}
"sync": {
"enabled": True, with (
"cron": "invalid cron" patch.object(scheduler, "scheduler") as mock_scheduler,
} patch("leggend.background.scheduler.config") as mock_config_obj,
} ):
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
patch('leggend.background.scheduler.config') as mock_config_obj:
mock_config_obj.scheduler_config = invalid_cron_config mock_config_obj.scheduler_config = invalid_cron_config
mock_scheduler.running = False mock_scheduler.running = False
scheduler.start() scheduler.start()
# With invalid cron, scheduler.start() should not be called due to early return # With invalid cron, scheduler.start() should not be called due to early return
# and add_job should not be called # and add_job should not be called
mock_scheduler.start.assert_not_called() mock_scheduler.start.assert_not_called()
@@ -123,24 +115,20 @@ class TestBackgroundScheduler:
def test_scheduler_shutdown(self, scheduler): def test_scheduler_shutdown(self, scheduler):
"""Test scheduler shutdown.""" """Test scheduler shutdown."""
scheduler.scheduler.running = True scheduler.scheduler.running = True
scheduler.shutdown() scheduler.shutdown()
scheduler.scheduler.shutdown.assert_called_once() scheduler.scheduler.shutdown.assert_called_once()
def test_reschedule_sync(self, scheduler, mock_config): def test_reschedule_sync(self, scheduler, mock_config):
"""Test rescheduling sync job.""" """Test rescheduling sync job."""
scheduler.scheduler.running = True scheduler.scheduler.running = True
# Reschedule with new config # Reschedule with new config
new_config = { new_config = {"enabled": True, "hour": 6, "minute": 30}
"enabled": True,
"hour": 6,
"minute": 30
}
scheduler.reschedule_sync(new_config) scheduler.reschedule_sync(new_config)
# Verify remove_job and add_job were called # Verify remove_job and add_job were called
scheduler.scheduler.remove_job.assert_called_once_with("daily_sync") scheduler.scheduler.remove_job.assert_called_once_with("daily_sync")
scheduler.scheduler.add_job.assert_called_once() scheduler.scheduler.add_job.assert_called_once()
@@ -148,11 +136,11 @@ class TestBackgroundScheduler:
def test_reschedule_sync_disable(self, scheduler, mock_config): def test_reschedule_sync_disable(self, scheduler, mock_config):
"""Test disabling sync via reschedule.""" """Test disabling sync via reschedule."""
scheduler.scheduler.running = True scheduler.scheduler.running = True
# Disable sync # Disable sync
disabled_config = {"enabled": False} disabled_config = {"enabled": False}
scheduler.reschedule_sync(disabled_config) scheduler.reschedule_sync(disabled_config)
# Job should be removed but not re-added # Job should be removed but not re-added
scheduler.scheduler.remove_job.assert_called_once_with("daily_sync") scheduler.scheduler.remove_job.assert_called_once_with("daily_sync")
scheduler.scheduler.add_job.assert_not_called() scheduler.scheduler.add_job.assert_not_called()
@@ -162,9 +150,9 @@ class TestBackgroundScheduler:
mock_job = MagicMock() mock_job = MagicMock()
mock_job.next_run_time = datetime(2025, 9, 2, 3, 0) mock_job.next_run_time = datetime(2025, 9, 2, 3, 0)
scheduler.scheduler.get_job.return_value = mock_job scheduler.scheduler.get_job.return_value = mock_job
next_time = scheduler.get_next_sync_time() next_time = scheduler.get_next_sync_time()
assert next_time is not None assert next_time is not None
assert isinstance(next_time, datetime) assert isinstance(next_time, datetime)
scheduler.scheduler.get_job.assert_called_once_with("daily_sync") scheduler.scheduler.get_job.assert_called_once_with("daily_sync")
@@ -172,9 +160,9 @@ class TestBackgroundScheduler:
def test_get_next_sync_time_no_job(self, scheduler): def test_get_next_sync_time_no_job(self, scheduler):
"""Test getting next sync time when no job is scheduled.""" """Test getting next sync time when no job is scheduled."""
scheduler.scheduler.get_job.return_value = None scheduler.scheduler.get_job.return_value = None
next_time = scheduler.get_next_sync_time() next_time = scheduler.get_next_sync_time()
assert next_time is None assert next_time is None
scheduler.scheduler.get_job.assert_called_once_with("daily_sync") scheduler.scheduler.get_job.assert_called_once_with("daily_sync")
@@ -183,9 +171,9 @@ class TestBackgroundScheduler:
"""Test successful sync job execution.""" """Test successful sync job execution."""
mock_sync_service = AsyncMock() mock_sync_service = AsyncMock()
scheduler.sync_service = mock_sync_service scheduler.sync_service = mock_sync_service
await scheduler._run_sync() await scheduler._run_sync()
mock_sync_service.sync_all_accounts.assert_called_once() mock_sync_service.sync_all_accounts.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -194,18 +182,18 @@ class TestBackgroundScheduler:
mock_sync_service = AsyncMock() mock_sync_service = AsyncMock()
mock_sync_service.sync_all_accounts.side_effect = Exception("Sync failed") mock_sync_service.sync_all_accounts.side_effect = Exception("Sync failed")
scheduler.sync_service = mock_sync_service scheduler.sync_service = mock_sync_service
# Should not raise exception, just log error # Should not raise exception, just log error
await scheduler._run_sync() await scheduler._run_sync()
mock_sync_service.sync_all_accounts.assert_called_once() mock_sync_service.sync_all_accounts.assert_called_once()
def test_scheduler_job_max_instances(self, scheduler, mock_config): def test_scheduler_job_max_instances(self, scheduler, mock_config):
"""Test that sync jobs have max_instances=1.""" """Test that sync jobs have max_instances=1."""
with patch('leggend.config.config') as mock_config_obj: with patch("leggend.config.config") as mock_config_obj:
mock_config_obj.scheduler_config = mock_config mock_config_obj.scheduler_config = mock_config
scheduler.start() scheduler.start()
# Verify add_job was called with max_instances=1 # Verify add_job was called with max_instances=1
call_args = scheduler.scheduler.add_job.call_args call_args = scheduler.scheduler.add_job.call_args
assert call_args.kwargs['max_instances'] == 1 assert call_args.kwargs["max_instances"] == 1