chore: Implement code review suggestions and format code.

This commit is contained in:
Elisiário Couto
2025-09-03 21:11:19 +01:00
committed by Elisiário Couto
parent 47164e8546
commit de3da84dff
42 changed files with 1144 additions and 966 deletions

View File

@@ -1,32 +0,0 @@
FROM python:3.13-alpine AS builder
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
WORKDIR /app
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --frozen --no-install-project --no-editable
COPY . /app
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-editable --no-group dev
FROM python:3.13-alpine
LABEL org.opencontainers.image.source="https://github.com/elisiariocouto/leggen"
LABEL org.opencontainers.image.authors="Elisiário Couto <elisiario@couto.io>"
LABEL org.opencontainers.image.licenses="MIT"
LABEL org.opencontainers.image.title="leggend"
LABEL org.opencontainers.image.description="Leggen API service"
LABEL org.opencontainers.image.url="https://github.com/elisiariocouto/leggen"
WORKDIR /app
ENV PATH="/app/.venv/bin:$PATH"
COPY --from=builder --chown=app:app /app/.venv /app/.venv
EXPOSE 8000
ENTRYPOINT ["/app/.venv/bin/leggend"]

View File

@@ -3,7 +3,6 @@ services:
leggend: leggend:
build: build:
context: . context: .
dockerfile: Dockerfile.leggend
restart: "unless-stopped" restart: "unless-stopped"
ports: ports:
- "127.0.0.1:8000:8000" - "127.0.0.1:8000:8000"
@@ -18,20 +17,6 @@ services:
timeout: 10s timeout: 10s
retries: 3 retries: 3
# CLI for one-off operations (uses leggend API)
leggen:
image: elisiariocouto/leggen:latest
command: sync --wait
restart: "no"
volumes:
- "./leggen:/root/.config/leggen"
- "./db:/app"
environment:
- LEGGEND_API_URL=http://leggend:8000
depends_on:
leggend:
condition: service_healthy
nocodb: nocodb:
image: nocodb/nocodb:latest image: nocodb/nocodb:latest
restart: "unless-stopped" restart: "unless-stopped"

View File

@@ -10,12 +10,13 @@ class LeggendAPIClient:
"""Client for communicating with the leggend FastAPI service""" """Client for communicating with the leggend FastAPI service"""
def __init__(self, base_url: Optional[str] = None): def __init__(self, base_url: Optional[str] = None):
self.base_url = base_url or os.environ.get("LEGGEND_API_URL", "http://localhost:8000") self.base_url = base_url or os.environ.get(
"LEGGEND_API_URL", "http://localhost:8000"
)
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({ self.session.headers.update(
"Content-Type": "application/json", {"Content-Type": "application/json", "Accept": "application/json"}
"Accept": "application/json" )
})
def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]: def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
"""Make HTTP request to the API""" """Make HTTP request to the API"""
@@ -53,15 +54,19 @@ class LeggendAPIClient:
# Bank endpoints # Bank endpoints
def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]: def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]:
"""Get bank institutions for a country""" """Get bank institutions for a country"""
response = self._make_request("GET", "/api/v1/banks/institutions", params={"country": country}) response = self._make_request(
"GET", "/api/v1/banks/institutions", params={"country": country}
)
return response.get("data", []) return response.get("data", [])
def connect_to_bank(self, institution_id: str, redirect_url: str = "http://localhost:8000/") -> Dict[str, Any]: def connect_to_bank(
self, institution_id: str, redirect_url: str = "http://localhost:8000/"
) -> Dict[str, Any]:
"""Connect to a bank""" """Connect to a bank"""
response = self._make_request( response = self._make_request(
"POST", "POST",
"/api/v1/banks/connect", "/api/v1/banks/connect",
json={"institution_id": institution_id, "redirect_url": redirect_url} json={"institution_id": institution_id, "redirect_url": redirect_url},
) )
return response.get("data", {}) return response.get("data", {})
@@ -91,17 +96,21 @@ class LeggendAPIClient:
response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances") response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances")
return response.get("data", []) return response.get("data", [])
def get_account_transactions(self, account_id: str, limit: int = 100, summary_only: bool = False) -> List[Dict[str, Any]]: def get_account_transactions(
self, account_id: str, limit: int = 100, summary_only: bool = False
) -> List[Dict[str, Any]]:
"""Get account transactions""" """Get account transactions"""
response = self._make_request( response = self._make_request(
"GET", "GET",
f"/api/v1/accounts/{account_id}/transactions", f"/api/v1/accounts/{account_id}/transactions",
params={"limit": limit, "summary_only": summary_only} params={"limit": limit, "summary_only": summary_only},
) )
return response.get("data", []) return response.get("data", [])
# Transaction endpoints # Transaction endpoints
def get_all_transactions(self, limit: int = 100, summary_only: bool = True, **filters) -> List[Dict[str, Any]]: def get_all_transactions(
self, limit: int = 100, summary_only: bool = True, **filters
) -> List[Dict[str, Any]]:
"""Get all transactions with optional filters""" """Get all transactions with optional filters"""
params = {"limit": limit, "summary_only": summary_only} params = {"limit": limit, "summary_only": summary_only}
params.update(filters) params.update(filters)
@@ -109,13 +118,17 @@ class LeggendAPIClient:
response = self._make_request("GET", "/api/v1/transactions", params=params) response = self._make_request("GET", "/api/v1/transactions", params=params)
return response.get("data", []) return response.get("data", [])
def get_transaction_stats(self, days: int = 30, account_id: Optional[str] = None) -> Dict[str, Any]: def get_transaction_stats(
self, days: int = 30, account_id: Optional[str] = None
) -> Dict[str, Any]:
"""Get transaction statistics""" """Get transaction statistics"""
params = {"days": days} params = {"days": days}
if account_id: if account_id:
params["account_id"] = account_id params["account_id"] = account_id
response = self._make_request("GET", "/api/v1/transactions/stats", params=params) response = self._make_request(
"GET", "/api/v1/transactions/stats", params=params
)
return response.get("data", {}) return response.get("data", {})
# Sync endpoints # Sync endpoints
@@ -124,7 +137,9 @@ class LeggendAPIClient:
response = self._make_request("GET", "/api/v1/sync/status") response = self._make_request("GET", "/api/v1/sync/status")
return response.get("data", {}) return response.get("data", {})
def trigger_sync(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]: def trigger_sync(
self, account_ids: Optional[List[str]] = None, force: bool = False
) -> Dict[str, Any]:
"""Trigger a sync""" """Trigger a sync"""
data = {"force": force} data = {"force": force}
if account_ids: if account_ids:
@@ -133,7 +148,9 @@ class LeggendAPIClient:
response = self._make_request("POST", "/api/v1/sync", json=data) response = self._make_request("POST", "/api/v1/sync", json=data)
return response.get("data", {}) return response.get("data", {})
def sync_now(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]: def sync_now(
self, account_ids: Optional[List[str]] = None, force: bool = False
) -> Dict[str, Any]:
"""Run sync synchronously""" """Run sync synchronously"""
data = {"force": force} data = {"force": force}
if account_ids: if account_ids:
@@ -147,7 +164,13 @@ class LeggendAPIClient:
response = self._make_request("GET", "/api/v1/sync/scheduler") response = self._make_request("GET", "/api/v1/sync/scheduler")
return response.get("data", {}) return response.get("data", {})
def update_scheduler_config(self, enabled: bool = True, hour: int = 3, minute: int = 0, cron: Optional[str] = None) -> Dict[str, Any]: def update_scheduler_config(
self,
enabled: bool = True,
hour: int = 3,
minute: int = 0,
cron: Optional[str] = None,
) -> Dict[str, Any]:
"""Update scheduler configuration""" """Update scheduler configuration"""
data = {"enabled": enabled, "hour": hour, "minute": minute} data = {"enabled": enabled, "hour": hour, "minute": minute}
if cron: if cron:

View File

@@ -15,7 +15,9 @@ def balances(ctx: click.Context):
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
accounts = api_client.get_accounts() accounts = api_client.get_accounts()
@@ -24,11 +26,7 @@ def balances(ctx: click.Context):
for account in accounts: for account in accounts:
for balance in account.get("balances", []): for balance in account.get("balances", []):
amount = round(float(balance["amount"]), 2) amount = round(float(balance["amount"]), 2)
symbol = ( symbol = "" if balance["currency"] == "EUR" else f" {balance['currency']}"
""
if balance["currency"] == "EUR"
else f" {balance['currency']}"
)
amount_str = f"{amount}{symbol}" amount_str = f"{amount}{symbol}"
date = ( date = (
datefmt(balance.get("last_change_date")) datefmt(balance.get("last_change_date"))

View File

@@ -1,36 +0,0 @@
import os
import click
from leggen.main import cli
cmd_folder = os.path.abspath(os.path.dirname(__file__))
class BankGroup(click.Group):
def list_commands(self, ctx):
rv = []
for filename in os.listdir(cmd_folder):
if filename.endswith(".py") and not filename.startswith("__init__"):
if filename == "list_banks.py":
rv.append("list")
else:
rv.append(filename[:-3])
rv.sort()
return rv
def get_command(self, ctx, name):
try:
if name == "list":
name = "list_banks"
mod = __import__(f"leggen.commands.bank.{name}", None, None, [name])
except ImportError:
return
return getattr(mod, name)
@cli.group(cls=BankGroup)
@click.pass_context
def bank(ctx):
"""Manage banks connections"""
return

View File

@@ -16,7 +16,9 @@ def add(ctx):
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
try: try:
@@ -65,11 +67,15 @@ def add(ctx):
save_file(f"req_{result['id']}.json", result) save_file(f"req_{result['id']}.json", result)
success("Bank connection request created successfully!") success("Bank connection request created successfully!")
warning(f"Please open the following URL in your browser to complete the authorization:") warning(
f"Please open the following URL in your browser to complete the authorization:"
)
click.echo(f"\n{result['link']}\n") click.echo(f"\n{result['link']}\n")
info(f"Requisition ID: {result['id']}") info(f"Requisition ID: {result['id']}")
info("After completing the authorization, you can check the connection status with 'leggen status'") info(
"After completing the authorization, you can check the connection status with 'leggen status'"
)
except Exception as e: except Exception as e:
click.echo(f"Error: Failed to connect to bank: {str(e)}") click.echo(f"Error: Failed to connect to bank: {str(e)}")

View File

@@ -15,7 +15,9 @@ def status(ctx: click.Context):
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
# Get bank connection status # Get bank connection status

View File

@@ -6,8 +6,8 @@ from leggen.utils.text import error, info, success
@cli.command() @cli.command()
@click.option('--wait', is_flag=True, help='Wait for sync to complete (synchronous)') @click.option("--wait", is_flag=True, help="Wait for sync to complete (synchronous)")
@click.option('--force', is_flag=True, help='Force sync even if already running') @click.option("--force", is_flag=True, help="Force sync even if already running")
@click.pass_context @click.pass_context
def sync(ctx: click.Context, wait: bool, force: bool): def sync(ctx: click.Context, wait: bool, force: bool):
""" """
@@ -31,17 +31,17 @@ def sync(ctx: click.Context, wait: bool, force: bool):
info(f"Accounts processed: {result.get('accounts_processed', 0)}") info(f"Accounts processed: {result.get('accounts_processed', 0)}")
info(f"Transactions added: {result.get('transactions_added', 0)}") info(f"Transactions added: {result.get('transactions_added', 0)}")
info(f"Balances updated: {result.get('balances_updated', 0)}") info(f"Balances updated: {result.get('balances_updated', 0)}")
if result.get('duration_seconds'): if result.get("duration_seconds"):
info(f"Duration: {result['duration_seconds']:.2f} seconds") info(f"Duration: {result['duration_seconds']:.2f} seconds")
if result.get('errors'): if result.get("errors"):
error(f"Errors encountered: {len(result['errors'])}") error(f"Errors encountered: {len(result['errors'])}")
for err in result['errors']: for err in result["errors"]:
error(f" - {err}") error(f" - {err}")
else: else:
error("Sync failed") error("Sync failed")
if result.get('errors'): if result.get("errors"):
for err in result['errors']: for err in result["errors"]:
error(f" - {err}") error(f" - {err}")
else: else:
# Trigger async sync # Trigger async sync
@@ -50,7 +50,9 @@ def sync(ctx: click.Context, wait: bool, force: bool):
if result.get("sync_started"): if result.get("sync_started"):
success("Sync started successfully in the background") success("Sync started successfully in the background")
info("Use 'leggen sync --wait' to run synchronously or check status with API") info(
"Use 'leggen sync --wait' to run synchronously or check status with API"
)
else: else:
error("Failed to start sync") error("Failed to start sync")

View File

@@ -7,7 +7,9 @@ from leggen.utils.text import datefmt, info, print_table
@cli.command() @cli.command()
@click.option("-a", "--account", type=str, help="Account ID") @click.option("-a", "--account", type=str, help="Account ID")
@click.option("-l", "--limit", type=int, default=50, help="Number of transactions to show") @click.option(
"-l", "--limit", type=int, default=50, help="Number of transactions to show"
)
@click.option("--full", is_flag=True, help="Show full transaction details") @click.option("--full", is_flag=True, help="Show full transaction details")
@click.pass_context @click.pass_context
def transactions(ctx: click.Context, account: str, limit: int, full: bool): def transactions(ctx: click.Context, account: str, limit: int, full: bool):
@@ -22,7 +24,9 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
# Check if leggend service is available # Check if leggend service is available
if not api_client.health_check(): if not api_client.health_check():
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") click.echo(
"Error: Cannot connect to leggend service. Please ensure it's running."
)
return return
try: try:
@@ -39,9 +43,7 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
else: else:
# Get all transactions # Get all transactions
transactions_data = api_client.get_all_transactions( transactions_data = api_client.get_all_transactions(
limit=limit, limit=limit, summary_only=not full, account_id=account
summary_only=not full,
account_id=account
) )
# Format transactions for display # Format transactions for display
@@ -49,24 +51,32 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
# Full transaction details # Full transaction details
formatted_transactions = [] formatted_transactions = []
for txn in transactions_data: for txn in transactions_data:
formatted_transactions.append({ formatted_transactions.append(
"ID": txn["internal_transaction_id"][:12] + "...", {
"Date": datefmt(txn["transaction_date"]), "ID": txn["internal_transaction_id"][:12] + "...",
"Description": txn["description"][:50] + "..." if len(txn["description"]) > 50 else txn["description"], "Date": datefmt(txn["transaction_date"]),
"Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}", "Description": txn["description"][:50] + "..."
"Status": txn["transaction_status"].upper(), if len(txn["description"]) > 50
"Account": txn["account_id"][:8] + "...", else txn["description"],
}) "Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}",
"Status": txn["transaction_status"].upper(),
"Account": txn["account_id"][:8] + "...",
}
)
else: else:
# Summary view # Summary view
formatted_transactions = [] formatted_transactions = []
for txn in transactions_data: for txn in transactions_data:
formatted_transactions.append({ formatted_transactions.append(
"Date": datefmt(txn["date"]), {
"Description": txn["description"][:60] + "..." if len(txn["description"]) > 60 else txn["description"], "Date": datefmt(txn["date"]),
"Amount": f"{txn['amount']:.2f} {txn['currency']}", "Description": txn["description"][:60] + "..."
"Status": txn["status"].upper(), if len(txn["description"]) > 60
}) else txn["description"],
"Amount": f"{txn['amount']:.2f} {txn['currency']}",
"Status": txn["status"].upper(),
}
)
if formatted_transactions: if formatted_transactions:
print_table(formatted_transactions) print_table(formatted_transactions)

View File

@@ -90,10 +90,10 @@ class Group(click.Group):
@click.option( @click.option(
"--api-url", "--api-url",
type=str, type=str,
default=None, default="http://localhost:8000",
envvar="LEGGEND_API_URL", envvar="LEGGEND_API_URL",
show_envvar=True, show_envvar=True,
help="URL of the leggend API service (default: http://localhost:8000)", help="URL of the leggend API service",
) )
@click.group( @click.group(
cls=Group, cls=Group,

View File

View File

@@ -6,19 +6,19 @@ from pydantic import BaseModel
class AccountBalance(BaseModel): class AccountBalance(BaseModel):
"""Account balance model""" """Account balance model"""
amount: float amount: float
currency: str currency: str
balance_type: str balance_type: str
last_change_date: Optional[datetime] = None last_change_date: Optional[datetime] = None
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class AccountDetails(BaseModel): class AccountDetails(BaseModel):
"""Account details model""" """Account details model"""
id: str id: str
institution_id: str institution_id: str
status: str status: str
@@ -30,13 +30,12 @@ class AccountDetails(BaseModel):
balances: List[AccountBalance] = [] balances: List[AccountBalance] = []
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class Transaction(BaseModel): class Transaction(BaseModel):
"""Transaction model""" """Transaction model"""
internal_transaction_id: str internal_transaction_id: str
institution_id: str institution_id: str
iban: Optional[str] = None iban: Optional[str] = None
@@ -49,13 +48,12 @@ class Transaction(BaseModel):
raw_transaction: Dict[str, Any] raw_transaction: Dict[str, Any]
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class TransactionSummary(BaseModel): class TransactionSummary(BaseModel):
"""Transaction summary for lists""" """Transaction summary for lists"""
internal_transaction_id: str internal_transaction_id: str
date: datetime date: datetime
description: str description: str
@@ -65,6 +63,4 @@ class TransactionSummary(BaseModel):
account_id: str account_id: str
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
class BankInstitution(BaseModel): class BankInstitution(BaseModel):
"""Bank institution model""" """Bank institution model"""
id: str id: str
name: str name: str
bic: Optional[str] = None bic: Optional[str] = None
@@ -16,12 +17,14 @@ class BankInstitution(BaseModel):
class BankConnectionRequest(BaseModel): class BankConnectionRequest(BaseModel):
"""Request to connect to a bank""" """Request to connect to a bank"""
institution_id: str institution_id: str
redirect_url: Optional[str] = "http://localhost:8000/" redirect_url: Optional[str] = "http://localhost:8000/"
class BankRequisition(BaseModel): class BankRequisition(BaseModel):
"""Bank requisition/connection model""" """Bank requisition/connection model"""
id: str id: str
institution_id: str institution_id: str
status: str status: str
@@ -31,13 +34,12 @@ class BankRequisition(BaseModel):
accounts: List[str] = [] accounts: List[str] = []
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class BankConnectionStatus(BaseModel): class BankConnectionStatus(BaseModel):
"""Bank connection status response""" """Bank connection status response"""
bank_id: str bank_id: str
bank_name: str bank_name: str
status: str status: str
@@ -47,6 +49,4 @@ class BankConnectionStatus(BaseModel):
accounts_count: int accounts_count: int
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
class APIResponse(BaseModel): class APIResponse(BaseModel):
"""Base API response model""" """Base API response model"""
success: bool = True success: bool = True
message: Optional[str] = None message: Optional[str] = None
data: Optional[Any] = None data: Optional[Any] = None
@@ -13,6 +14,7 @@ class APIResponse(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Error response model""" """Error response model"""
success: bool = False success: bool = False
message: str message: str
error_code: Optional[str] = None error_code: Optional[str] = None
@@ -21,6 +23,7 @@ class ErrorResponse(BaseModel):
class PaginatedResponse(BaseModel): class PaginatedResponse(BaseModel):
"""Paginated response model""" """Paginated response model"""
success: bool = True success: bool = True
data: list data: list
pagination: Dict[str, Any] pagination: Dict[str, Any]

View File

@@ -5,12 +5,14 @@ from pydantic import BaseModel
class DiscordConfig(BaseModel): class DiscordConfig(BaseModel):
"""Discord notification configuration""" """Discord notification configuration"""
webhook: str webhook: str
enabled: bool = True enabled: bool = True
class TelegramConfig(BaseModel): class TelegramConfig(BaseModel):
"""Telegram notification configuration""" """Telegram notification configuration"""
token: str token: str
chat_id: int chat_id: int
enabled: bool = True enabled: bool = True
@@ -18,6 +20,7 @@ class TelegramConfig(BaseModel):
class NotificationFilters(BaseModel): class NotificationFilters(BaseModel):
"""Notification filters configuration""" """Notification filters configuration"""
case_insensitive: Dict[str, str] = {} case_insensitive: Dict[str, str] = {}
case_sensitive: Optional[Dict[str, str]] = None case_sensitive: Optional[Dict[str, str]] = None
amount_threshold: Optional[float] = None amount_threshold: Optional[float] = None
@@ -26,6 +29,7 @@ class NotificationFilters(BaseModel):
class NotificationSettings(BaseModel): class NotificationSettings(BaseModel):
"""Complete notification settings""" """Complete notification settings"""
discord: Optional[DiscordConfig] = None discord: Optional[DiscordConfig] = None
telegram: Optional[TelegramConfig] = None telegram: Optional[TelegramConfig] = None
filters: NotificationFilters = NotificationFilters() filters: NotificationFilters = NotificationFilters()
@@ -33,12 +37,14 @@ class NotificationSettings(BaseModel):
class NotificationTest(BaseModel): class NotificationTest(BaseModel):
"""Test notification request""" """Test notification request"""
service: str # "discord" or "telegram" service: str # "discord" or "telegram"
message: str = "Test notification from Leggen" message: str = "Test notification from Leggen"
class NotificationHistory(BaseModel): class NotificationHistory(BaseModel):
"""Notification history entry""" """Notification history entry"""
id: str id: str
service: str service: str
message: str message: str

View File

@@ -6,12 +6,14 @@ from pydantic import BaseModel
class SyncRequest(BaseModel): class SyncRequest(BaseModel):
"""Request to trigger a sync""" """Request to trigger a sync"""
account_ids: Optional[list[str]] = None # If None, sync all accounts account_ids: Optional[list[str]] = None # If None, sync all accounts
force: bool = False # Force sync even if recently synced force: bool = False # Force sync even if recently synced
class SyncStatus(BaseModel): class SyncStatus(BaseModel):
"""Sync operation status""" """Sync operation status"""
is_running: bool is_running: bool
last_sync: Optional[datetime] = None last_sync: Optional[datetime] = None
next_sync: Optional[datetime] = None next_sync: Optional[datetime] = None
@@ -21,13 +23,12 @@ class SyncStatus(BaseModel):
errors: list[str] = [] errors: list[str] = []
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class SyncResult(BaseModel): class SyncResult(BaseModel):
"""Result of a sync operation""" """Result of a sync operation"""
success: bool success: bool
accounts_processed: int accounts_processed: int
transactions_added: int transactions_added: int
@@ -39,13 +40,12 @@ class SyncResult(BaseModel):
completed_at: datetime completed_at: datetime
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class SchedulerConfig(BaseModel): class SchedulerConfig(BaseModel):
"""Scheduler configuration model""" """Scheduler configuration model"""
enabled: bool = True enabled: bool = True
hour: Optional[int] = 3 hour: Optional[int] = 3
minute: Optional[int] = 0 minute: Optional[int] = 0

View File

@@ -3,7 +3,12 @@ from fastapi import APIRouter, HTTPException, Query
from loguru import logger from loguru import logger
from leggend.api.models.common import APIResponse from leggend.api.models.common import APIResponse
from leggend.api.models.accounts import AccountDetails, AccountBalance, Transaction, TransactionSummary from leggend.api.models.accounts import (
AccountDetails,
AccountBalance,
Transaction,
TransactionSummary,
)
from leggend.services.gocardless_service import GoCardlessService from leggend.services.gocardless_service import GoCardlessService
from leggend.services.database_service import DatabaseService from leggend.services.database_service import DatabaseService
@@ -25,40 +30,46 @@ async def get_all_accounts() -> APIResponse:
accounts = [] accounts = []
for account_id in all_accounts: for account_id in all_accounts:
try: try:
account_details = await gocardless_service.get_account_details(account_id) account_details = await gocardless_service.get_account_details(
balances_data = await gocardless_service.get_account_balances(account_id) account_id
)
balances_data = await gocardless_service.get_account_balances(
account_id
)
# Process balances # Process balances
balances = [] balances = []
for balance in balances_data.get("balances", []): for balance in balances_data.get("balances", []):
balance_amount = balance["balanceAmount"] balance_amount = balance["balanceAmount"]
balances.append(AccountBalance( balances.append(
amount=float(balance_amount["amount"]), AccountBalance(
currency=balance_amount["currency"], amount=float(balance_amount["amount"]),
balance_type=balance["balanceType"], currency=balance_amount["currency"],
last_change_date=balance.get("lastChangeDateTime") balance_type=balance["balanceType"],
)) last_change_date=balance.get("lastChangeDateTime"),
)
)
accounts.append(AccountDetails( accounts.append(
id=account_details["id"], AccountDetails(
institution_id=account_details["institution_id"], id=account_details["id"],
status=account_details["status"], institution_id=account_details["institution_id"],
iban=account_details.get("iban"), status=account_details["status"],
name=account_details.get("name"), iban=account_details.get("iban"),
currency=account_details.get("currency"), name=account_details.get("name"),
created=account_details["created"], currency=account_details.get("currency"),
last_accessed=account_details.get("last_accessed"), created=account_details["created"],
balances=balances last_accessed=account_details.get("last_accessed"),
)) balances=balances,
)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to get details for account {account_id}: {e}") logger.error(f"Failed to get details for account {account_id}: {e}")
continue continue
return APIResponse( return APIResponse(
success=True, success=True, data=accounts, message=f"Retrieved {len(accounts)} accounts"
data=accounts,
message=f"Retrieved {len(accounts)} accounts"
) )
except Exception as e: except Exception as e:
@@ -77,12 +88,14 @@ async def get_account_details(account_id: str) -> APIResponse:
balances = [] balances = []
for balance in balances_data.get("balances", []): for balance in balances_data.get("balances", []):
balance_amount = balance["balanceAmount"] balance_amount = balance["balanceAmount"]
balances.append(AccountBalance( balances.append(
amount=float(balance_amount["amount"]), AccountBalance(
currency=balance_amount["currency"], amount=float(balance_amount["amount"]),
balance_type=balance["balanceType"], currency=balance_amount["currency"],
last_change_date=balance.get("lastChangeDateTime") balance_type=balance["balanceType"],
)) last_change_date=balance.get("lastChangeDateTime"),
)
)
account = AccountDetails( account = AccountDetails(
id=account_details["id"], id=account_details["id"],
@@ -93,13 +106,13 @@ async def get_account_details(account_id: str) -> APIResponse:
currency=account_details.get("currency"), currency=account_details.get("currency"),
created=account_details["created"], created=account_details["created"],
last_accessed=account_details.get("last_accessed"), last_accessed=account_details.get("last_accessed"),
balances=balances balances=balances,
) )
return APIResponse( return APIResponse(
success=True, success=True,
data=account, data=account,
message=f"Account details retrieved for {account_id}" message=f"Account details retrieved for {account_id}",
) )
except Exception as e: except Exception as e:
@@ -116,17 +129,19 @@ async def get_account_balances(account_id: str) -> APIResponse:
balances = [] balances = []
for balance in balances_data.get("balances", []): for balance in balances_data.get("balances", []):
balance_amount = balance["balanceAmount"] balance_amount = balance["balanceAmount"]
balances.append(AccountBalance( balances.append(
amount=float(balance_amount["amount"]), AccountBalance(
currency=balance_amount["currency"], amount=float(balance_amount["amount"]),
balance_type=balance["balanceType"], currency=balance_amount["currency"],
last_change_date=balance.get("lastChangeDateTime") balance_type=balance["balanceType"],
)) last_change_date=balance.get("lastChangeDateTime"),
)
)
return APIResponse( return APIResponse(
success=True, success=True,
data=balances, data=balances,
message=f"Retrieved {len(balances)} balances for account {account_id}" message=f"Retrieved {len(balances)} balances for account {account_id}",
) )
except Exception as e: except Exception as e:
@@ -139,12 +154,16 @@ async def get_account_transactions(
account_id: str, account_id: str,
limit: Optional[int] = Query(default=100, le=500), limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0), offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(default=False, description="Return transaction summaries only") summary_only: bool = Query(
default=False, description="Return transaction summaries only"
),
) -> APIResponse: ) -> APIResponse:
"""Get transactions for a specific account""" """Get transactions for a specific account"""
try: try:
account_details = await gocardless_service.get_account_details(account_id) account_details = await gocardless_service.get_account_details(account_id)
transactions_data = await gocardless_service.get_account_transactions(account_id) transactions_data = await gocardless_service.get_account_transactions(
account_id
)
# Process transactions # Process transactions
processed_transactions = database_service.process_transactions( processed_transactions = database_service.process_transactions(
@@ -153,7 +172,7 @@ async def get_account_transactions(
# Apply pagination # Apply pagination
total_transactions = len(processed_transactions) total_transactions = len(processed_transactions)
paginated_transactions = processed_transactions[offset:offset + limit] paginated_transactions = processed_transactions[offset : offset + limit]
if summary_only: if summary_only:
# Return simplified transaction summaries # Return simplified transaction summaries
@@ -165,7 +184,7 @@ async def get_account_transactions(
amount=txn["transactionValue"], amount=txn["transactionValue"],
currency=txn["transactionCurrency"], currency=txn["transactionCurrency"],
status=txn["transactionStatus"], status=txn["transactionStatus"],
account_id=txn["accountId"] account_id=txn["accountId"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
@@ -183,7 +202,7 @@ async def get_account_transactions(
transaction_value=txn["transactionValue"], transaction_value=txn["transactionValue"],
transaction_currency=txn["transactionCurrency"], transaction_currency=txn["transactionCurrency"],
transaction_status=txn["transactionStatus"], transaction_status=txn["transactionStatus"],
raw_transaction=txn["rawTransaction"] raw_transaction=txn["rawTransaction"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
@@ -192,9 +211,11 @@ async def get_account_transactions(
return APIResponse( return APIResponse(
success=True, success=True,
data=data, data=data,
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})" message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get transactions for account {account_id}: {e}") logger.error(f"Failed to get transactions for account {account_id}: {e}")
raise HTTPException(status_code=404, detail=f"Failed to get transactions: {str(e)}") raise HTTPException(
status_code=404, detail=f"Failed to get transactions: {str(e)}"
)

View File

@@ -7,7 +7,7 @@ from leggend.api.models.banks import (
BankInstitution, BankInstitution,
BankConnectionRequest, BankConnectionRequest,
BankRequisition, BankRequisition,
BankConnectionStatus BankConnectionStatus,
) )
from leggend.services.gocardless_service import GoCardlessService from leggend.services.gocardless_service import GoCardlessService
from leggend.utils.gocardless import REQUISITION_STATUS from leggend.utils.gocardless import REQUISITION_STATUS
@@ -18,7 +18,7 @@ gocardless_service = GoCardlessService()
@router.get("/banks/institutions", response_model=APIResponse) @router.get("/banks/institutions", response_model=APIResponse)
async def get_bank_institutions( async def get_bank_institutions(
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)") country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)"),
) -> APIResponse: ) -> APIResponse:
"""Get available bank institutions for a country""" """Get available bank institutions for a country"""
try: try:
@@ -31,7 +31,7 @@ async def get_bank_institutions(
bic=inst.get("bic"), bic=inst.get("bic"),
transaction_total_days=inst["transaction_total_days"], transaction_total_days=inst["transaction_total_days"],
countries=inst["countries"], countries=inst["countries"],
logo=inst.get("logo") logo=inst.get("logo"),
) )
for inst in institutions_data for inst in institutions_data
] ]
@@ -39,12 +39,14 @@ async def get_bank_institutions(
return APIResponse( return APIResponse(
success=True, success=True,
data=institutions, data=institutions,
message=f"Found {len(institutions)} institutions for {country}" message=f"Found {len(institutions)} institutions for {country}",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get institutions for {country}: {e}") logger.error(f"Failed to get institutions for {country}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get institutions: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get institutions: {str(e)}"
)
@router.post("/banks/connect", response_model=APIResponse) @router.post("/banks/connect", response_model=APIResponse)
@@ -52,8 +54,7 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
"""Create a connection to a bank (requisition)""" """Create a connection to a bank (requisition)"""
try: try:
requisition_data = await gocardless_service.create_requisition( requisition_data = await gocardless_service.create_requisition(
request.institution_id, request.institution_id, request.redirect_url
request.redirect_url
) )
requisition = BankRequisition( requisition = BankRequisition(
@@ -62,18 +63,20 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
status=requisition_data["status"], status=requisition_data["status"],
created=requisition_data["created"], created=requisition_data["created"],
link=requisition_data["link"], link=requisition_data["link"],
accounts=requisition_data.get("accounts", []) accounts=requisition_data.get("accounts", []),
) )
return APIResponse( return APIResponse(
success=True, success=True,
data=requisition, data=requisition,
message=f"Bank connection created. Please visit the link to authorize." message=f"Bank connection created. Please visit the link to authorize.",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to bank {request.institution_id}: {e}") logger.error(f"Failed to connect to bank {request.institution_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to connect to bank: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to connect to bank: {str(e)}"
)
@router.get("/banks/status", response_model=APIResponse) @router.get("/banks/status", response_model=APIResponse)
@@ -87,25 +90,31 @@ async def get_bank_connections_status() -> APIResponse:
status = req["status"] status = req["status"]
status_display = REQUISITION_STATUS.get(status, "UNKNOWN") status_display = REQUISITION_STATUS.get(status, "UNKNOWN")
connections.append(BankConnectionStatus( connections.append(
bank_id=req["institution_id"], BankConnectionStatus(
bank_name=req["institution_id"], # Could be enhanced with actual bank names bank_id=req["institution_id"],
status=status, bank_name=req[
status_display=status_display, "institution_id"
created_at=req["created"], ], # Could be enhanced with actual bank names
requisition_id=req["id"], status=status,
accounts_count=len(req.get("accounts", [])) status_display=status_display,
)) created_at=req["created"],
requisition_id=req["id"],
accounts_count=len(req.get("accounts", [])),
)
)
return APIResponse( return APIResponse(
success=True, success=True,
data=connections, data=connections,
message=f"Found {len(connections)} bank connections" message=f"Found {len(connections)} bank connections",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get bank connection status: {e}") logger.error(f"Failed to get bank connection status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get bank status: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get bank status: {str(e)}"
)
@router.delete("/banks/connections/{requisition_id}", response_model=APIResponse) @router.delete("/banks/connections/{requisition_id}", response_model=APIResponse)
@@ -116,12 +125,14 @@ async def delete_bank_connection(requisition_id: str) -> APIResponse:
# For now, return success # For now, return success
return APIResponse( return APIResponse(
success=True, success=True,
message=f"Bank connection {requisition_id} deleted successfully" message=f"Bank connection {requisition_id} deleted successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to delete bank connection {requisition_id}: {e}") logger.error(f"Failed to delete bank connection {requisition_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete connection: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to delete connection: {str(e)}"
)
@router.get("/banks/countries", response_model=APIResponse) @router.get("/banks/countries", response_model=APIResponse)
@@ -164,5 +175,5 @@ async def get_supported_countries() -> APIResponse:
return APIResponse( return APIResponse(
success=True, success=True,
data=countries, data=countries,
message="Supported countries retrieved successfully" message="Supported countries retrieved successfully",
) )

View File

@@ -8,7 +8,7 @@ from leggend.api.models.notifications import (
NotificationTest, NotificationTest,
DiscordConfig, DiscordConfig,
TelegramConfig, TelegramConfig,
NotificationFilters NotificationFilters,
) )
from leggend.services.notification_service import NotificationService from leggend.services.notification_service import NotificationService
from leggend.config import config from leggend.config import config
@@ -31,30 +31,36 @@ async def get_notification_settings() -> APIResponse:
settings = NotificationSettings( settings = NotificationSettings(
discord=DiscordConfig( discord=DiscordConfig(
webhook="***" if discord_config.get("webhook") else "", webhook="***" if discord_config.get("webhook") else "",
enabled=discord_config.get("enabled", True) enabled=discord_config.get("enabled", True),
) if discord_config.get("webhook") else None, )
if discord_config.get("webhook")
else None,
telegram=TelegramConfig( telegram=TelegramConfig(
token="***" if telegram_config.get("token") else "", token="***" if telegram_config.get("token") else "",
chat_id=telegram_config.get("chat_id", 0), chat_id=telegram_config.get("chat_id", 0),
enabled=telegram_config.get("enabled", True) enabled=telegram_config.get("enabled", True),
) if telegram_config.get("token") else None, )
if telegram_config.get("token")
else None,
filters=NotificationFilters( filters=NotificationFilters(
case_insensitive=filters_config.get("case-insensitive", {}), case_insensitive=filters_config.get("case-insensitive", {}),
case_sensitive=filters_config.get("case-sensitive"), case_sensitive=filters_config.get("case-sensitive"),
amount_threshold=filters_config.get("amount_threshold"), amount_threshold=filters_config.get("amount_threshold"),
keywords=filters_config.get("keywords", []) keywords=filters_config.get("keywords", []),
) ),
) )
return APIResponse( return APIResponse(
success=True, success=True,
data=settings, data=settings,
message="Notification settings retrieved successfully" message="Notification settings retrieved successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get notification settings: {e}") logger.error(f"Failed to get notification settings: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get notification settings: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get notification settings: {str(e)}"
)
@router.put("/notifications/settings", response_model=APIResponse) @router.put("/notifications/settings", response_model=APIResponse)
@@ -67,14 +73,14 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
if settings.discord: if settings.discord:
notifications_config["discord"] = { notifications_config["discord"] = {
"webhook": settings.discord.webhook, "webhook": settings.discord.webhook,
"enabled": settings.discord.enabled "enabled": settings.discord.enabled,
} }
if settings.telegram: if settings.telegram:
notifications_config["telegram"] = { notifications_config["telegram"] = {
"token": settings.telegram.token, "token": settings.telegram.token,
"chat_id": settings.telegram.chat_id, "chat_id": settings.telegram.chat_id,
"enabled": settings.telegram.enabled "enabled": settings.telegram.enabled,
} }
# Update filters config # Update filters config
@@ -97,12 +103,14 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
return APIResponse( return APIResponse(
success=True, success=True,
data={"updated": True}, data={"updated": True},
message="Notification settings updated successfully" message="Notification settings updated successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to update notification settings: {e}") logger.error(f"Failed to update notification settings: {e}")
raise HTTPException(status_code=500, detail=f"Failed to update notification settings: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to update notification settings: {str(e)}"
)
@router.post("/notifications/test", response_model=APIResponse) @router.post("/notifications/test", response_model=APIResponse)
@@ -110,25 +118,26 @@ async def test_notification(test_request: NotificationTest) -> APIResponse:
"""Send a test notification""" """Send a test notification"""
try: try:
success = await notification_service.send_test_notification( success = await notification_service.send_test_notification(
test_request.service, test_request.service, test_request.message
test_request.message
) )
if success: if success:
return APIResponse( return APIResponse(
success=True, success=True,
data={"sent": True}, data={"sent": True},
message=f"Test notification sent to {test_request.service} successfully" message=f"Test notification sent to {test_request.service} successfully",
) )
else: else:
return APIResponse( return APIResponse(
success=False, success=False,
message=f"Failed to send test notification to {test_request.service}" message=f"Failed to send test notification to {test_request.service}",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send test notification: {e}") logger.error(f"Failed to send test notification: {e}")
raise HTTPException(status_code=500, detail=f"Failed to send test notification: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to send test notification: {str(e)}"
)
@router.get("/notifications/services", response_model=APIResponse) @router.get("/notifications/services", response_model=APIResponse)
@@ -141,32 +150,36 @@ async def get_notification_services() -> APIResponse:
"discord": { "discord": {
"name": "Discord", "name": "Discord",
"enabled": bool(notifications_config.get("discord", {}).get("webhook")), "enabled": bool(notifications_config.get("discord", {}).get("webhook")),
"configured": bool(notifications_config.get("discord", {}).get("webhook")), "configured": bool(
"active": notifications_config.get("discord", {}).get("enabled", True) notifications_config.get("discord", {}).get("webhook")
),
"active": notifications_config.get("discord", {}).get("enabled", True),
}, },
"telegram": { "telegram": {
"name": "Telegram", "name": "Telegram",
"enabled": bool( "enabled": bool(
notifications_config.get("telegram", {}).get("token") and notifications_config.get("telegram", {}).get("token")
notifications_config.get("telegram", {}).get("chat_id") and notifications_config.get("telegram", {}).get("chat_id")
), ),
"configured": bool( "configured": bool(
notifications_config.get("telegram", {}).get("token") and notifications_config.get("telegram", {}).get("token")
notifications_config.get("telegram", {}).get("chat_id") and notifications_config.get("telegram", {}).get("chat_id")
), ),
"active": notifications_config.get("telegram", {}).get("enabled", True) "active": notifications_config.get("telegram", {}).get("enabled", True),
} },
} }
return APIResponse( return APIResponse(
success=True, success=True,
data=services, data=services,
message="Notification services status retrieved successfully" message="Notification services status retrieved successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get notification services: {e}") logger.error(f"Failed to get notification services: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get notification services: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get notification services: {str(e)}"
)
@router.delete("/notifications/settings/{service}", response_model=APIResponse) @router.delete("/notifications/settings/{service}", response_model=APIResponse)
@@ -174,7 +187,9 @@ async def delete_notification_service(service: str) -> APIResponse:
"""Delete/disable a notification service""" """Delete/disable a notification service"""
try: try:
if service not in ["discord", "telegram"]: if service not in ["discord", "telegram"]:
raise HTTPException(status_code=400, detail="Service must be 'discord' or 'telegram'") raise HTTPException(
status_code=400, detail="Service must be 'discord' or 'telegram'"
)
notifications_config = config.notifications_config.copy() notifications_config = config.notifications_config.copy()
if service in notifications_config: if service in notifications_config:
@@ -184,9 +199,11 @@ async def delete_notification_service(service: str) -> APIResponse:
return APIResponse( return APIResponse(
success=True, success=True,
data={"deleted": service}, data={"deleted": service},
message=f"{service.capitalize()} notification service deleted successfully" message=f"{service.capitalize()} notification service deleted successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to delete notification service {service}: {e}") logger.error(f"Failed to delete notification service {service}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete notification service: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to delete notification service: {str(e)}"
)

View File

@@ -24,20 +24,19 @@ async def get_sync_status() -> APIResponse:
status.next_sync = next_sync_time status.next_sync = next_sync_time
return APIResponse( return APIResponse(
success=True, success=True, data=status, message="Sync status retrieved successfully"
data=status,
message="Sync status retrieved successfully"
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get sync status: {e}") logger.error(f"Failed to get sync status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get sync status: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get sync status: {str(e)}"
)
@router.post("/sync", response_model=APIResponse) @router.post("/sync", response_model=APIResponse)
async def trigger_sync( async def trigger_sync(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks, sync_request: Optional[SyncRequest] = None
sync_request: Optional[SyncRequest] = None
) -> APIResponse: ) -> APIResponse:
"""Trigger a manual sync operation""" """Trigger a manual sync operation"""
try: try:
@@ -46,7 +45,7 @@ async def trigger_sync(
if status.is_running and not (sync_request and sync_request.force): if status.is_running and not (sync_request and sync_request.force):
return APIResponse( return APIResponse(
success=False, success=False,
message="Sync is already running. Use 'force: true' to override." message="Sync is already running. Use 'force: true' to override.",
) )
# Determine what to sync # Determine what to sync
@@ -55,21 +54,26 @@ async def trigger_sync(
background_tasks.add_task( background_tasks.add_task(
sync_service.sync_specific_accounts, sync_service.sync_specific_accounts,
sync_request.account_ids, sync_request.account_ids,
sync_request.force if sync_request else False sync_request.force if sync_request else False,
)
message = (
f"Started sync for {len(sync_request.account_ids)} specific accounts"
) )
message = f"Started sync for {len(sync_request.account_ids)} specific accounts"
else: else:
# Sync all accounts in background # Sync all accounts in background
background_tasks.add_task( background_tasks.add_task(
sync_service.sync_all_accounts, sync_service.sync_all_accounts,
sync_request.force if sync_request else False sync_request.force if sync_request else False,
) )
message = "Started sync for all accounts" message = "Started sync for all accounts"
return APIResponse( return APIResponse(
success=True, success=True,
data={"sync_started": True, "force": sync_request.force if sync_request else False}, data={
message=message "sync_started": True,
"force": sync_request.force if sync_request else False,
},
message=message,
) )
except Exception as e: except Exception as e:
@@ -83,8 +87,7 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
try: try:
if sync_request and sync_request.account_ids: if sync_request and sync_request.account_ids:
result = await sync_service.sync_specific_accounts( result = await sync_service.sync_specific_accounts(
sync_request.account_ids, sync_request.account_ids, sync_request.force
sync_request.force
) )
else: else:
result = await sync_service.sync_all_accounts( result = await sync_service.sync_all_accounts(
@@ -94,7 +97,9 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
return APIResponse( return APIResponse(
success=result.success, success=result.success,
data=result, data=result,
message="Sync completed" if result.success else f"Sync failed with {len(result.errors)} errors" message="Sync completed"
if result.success
else f"Sync failed with {len(result.errors)} errors",
) )
except Exception as e: except Exception as e:
@@ -111,19 +116,25 @@ async def get_scheduler_config() -> APIResponse:
response_data = { response_data = {
**scheduler_config, **scheduler_config,
"next_scheduled_sync": next_sync_time.isoformat() if next_sync_time else None, "next_scheduled_sync": next_sync_time.isoformat()
"is_running": scheduler.scheduler.running if hasattr(scheduler, 'scheduler') else False if next_sync_time
else None,
"is_running": scheduler.scheduler.running
if hasattr(scheduler, "scheduler")
else False,
} }
return APIResponse( return APIResponse(
success=True, success=True,
data=response_data, data=response_data,
message="Scheduler configuration retrieved successfully" message="Scheduler configuration retrieved successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get scheduler config: {e}") logger.error(f"Failed to get scheduler config: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get scheduler config: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get scheduler config: {str(e)}"
)
@router.put("/sync/scheduler", response_model=APIResponse) @router.put("/sync/scheduler", response_model=APIResponse)
@@ -135,9 +146,13 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo
try: try:
cron_parts = scheduler_config.cron.split() cron_parts = scheduler_config.cron.split()
if len(cron_parts) != 5: if len(cron_parts) != 5:
raise ValueError("Cron expression must have 5 parts: minute hour day month day_of_week") raise ValueError(
"Cron expression must have 5 parts: minute hour day month day_of_week"
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid cron expression: {str(e)}") raise HTTPException(
status_code=400, detail=f"Invalid cron expression: {str(e)}"
)
# Update configuration # Update configuration
schedule_data = scheduler_config.dict(exclude_none=True) schedule_data = scheduler_config.dict(exclude_none=True)
@@ -149,12 +164,14 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo
return APIResponse( return APIResponse(
success=True, success=True,
data=schedule_data, data=schedule_data,
message="Scheduler configuration updated successfully" message="Scheduler configuration updated successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to update scheduler config: {e}") logger.error(f"Failed to update scheduler config: {e}")
raise HTTPException(status_code=500, detail=f"Failed to update scheduler config: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to update scheduler config: {str(e)}"
)
@router.post("/sync/scheduler/start", response_model=APIResponse) @router.post("/sync/scheduler/start", response_model=APIResponse)
@@ -163,19 +180,15 @@ async def start_scheduler() -> APIResponse:
try: try:
if not scheduler.scheduler.running: if not scheduler.scheduler.running:
scheduler.start() scheduler.start()
return APIResponse( return APIResponse(success=True, message="Scheduler started successfully")
success=True,
message="Scheduler started successfully"
)
else: else:
return APIResponse( return APIResponse(success=True, message="Scheduler is already running")
success=True,
message="Scheduler is already running"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to start scheduler: {e}") logger.error(f"Failed to start scheduler: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start scheduler: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to start scheduler: {str(e)}"
)
@router.post("/sync/scheduler/stop", response_model=APIResponse) @router.post("/sync/scheduler/stop", response_model=APIResponse)
@@ -184,16 +197,12 @@ async def stop_scheduler() -> APIResponse:
try: try:
if scheduler.scheduler.running: if scheduler.scheduler.running:
scheduler.shutdown() scheduler.shutdown()
return APIResponse( return APIResponse(success=True, message="Scheduler stopped successfully")
success=True,
message="Scheduler stopped successfully"
)
else: else:
return APIResponse( return APIResponse(success=True, message="Scheduler is already stopped")
success=True,
message="Scheduler is already stopped"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to stop scheduler: {e}") logger.error(f"Failed to stop scheduler: {e}")
raise HTTPException(status_code=500, detail=f"Failed to stop scheduler: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to stop scheduler: {str(e)}"
)

View File

@@ -17,13 +17,25 @@ database_service = DatabaseService()
async def get_all_transactions( async def get_all_transactions(
limit: Optional[int] = Query(default=100, le=500), limit: Optional[int] = Query(default=100, le=500),
offset: Optional[int] = Query(default=0, ge=0), offset: Optional[int] = Query(default=0, ge=0),
summary_only: bool = Query(default=True, description="Return transaction summaries only"), summary_only: bool = Query(
date_from: Optional[str] = Query(default=None, description="Filter from date (YYYY-MM-DD)"), default=True, description="Return transaction summaries only"
date_to: Optional[str] = Query(default=None, description="Filter to date (YYYY-MM-DD)"), ),
min_amount: Optional[float] = Query(default=None, description="Minimum transaction amount"), date_from: Optional[str] = Query(
max_amount: Optional[float] = Query(default=None, description="Maximum transaction amount"), default=None, description="Filter from date (YYYY-MM-DD)"
search: Optional[str] = Query(default=None, description="Search in transaction descriptions"), ),
account_id: Optional[str] = Query(default=None, description="Filter by account ID") date_to: Optional[str] = Query(
default=None, description="Filter to date (YYYY-MM-DD)"
),
min_amount: Optional[float] = Query(
default=None, description="Minimum transaction amount"
),
max_amount: Optional[float] = Query(
default=None, description="Maximum transaction amount"
),
search: Optional[str] = Query(
default=None, description="Search in transaction descriptions"
),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> APIResponse: ) -> APIResponse:
"""Get all transactions across all accounts with filtering options""" """Get all transactions across all accounts with filtering options"""
try: try:
@@ -46,7 +58,9 @@ async def get_all_transactions(
for acc_id in all_accounts: for acc_id in all_accounts:
try: try:
account_details = await gocardless_service.get_account_details(acc_id) account_details = await gocardless_service.get_account_details(acc_id)
transactions_data = await gocardless_service.get_account_transactions(acc_id) transactions_data = await gocardless_service.get_account_transactions(
acc_id
)
processed_transactions = database_service.process_transactions( processed_transactions = database_service.process_transactions(
acc_id, account_details, transactions_data acc_id, account_details, transactions_data
@@ -64,27 +78,31 @@ async def get_all_transactions(
if date_from: if date_from:
from_date = datetime.fromisoformat(date_from) from_date = datetime.fromisoformat(date_from)
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionDate"] >= from_date if txn["transactionDate"] >= from_date
] ]
if date_to: if date_to:
to_date = datetime.fromisoformat(date_to) to_date = datetime.fromisoformat(date_to)
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionDate"] <= to_date if txn["transactionDate"] <= to_date
] ]
# Amount filters # Amount filters
if min_amount is not None: if min_amount is not None:
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionValue"] >= min_amount if txn["transactionValue"] >= min_amount
] ]
if max_amount is not None: if max_amount is not None:
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if txn["transactionValue"] <= max_amount if txn["transactionValue"] <= max_amount
] ]
@@ -92,19 +110,17 @@ async def get_all_transactions(
if search: if search:
search_lower = search.lower() search_lower = search.lower()
filtered_transactions = [ filtered_transactions = [
txn for txn in filtered_transactions txn
for txn in filtered_transactions
if search_lower in txn["description"].lower() if search_lower in txn["description"].lower()
] ]
# Sort by date (newest first) # Sort by date (newest first)
filtered_transactions.sort( filtered_transactions.sort(key=lambda x: x["transactionDate"], reverse=True)
key=lambda x: x["transactionDate"],
reverse=True
)
# Apply pagination # Apply pagination
total_transactions = len(filtered_transactions) total_transactions = len(filtered_transactions)
paginated_transactions = filtered_transactions[offset:offset + limit] paginated_transactions = filtered_transactions[offset : offset + limit]
if summary_only: if summary_only:
# Return simplified transaction summaries # Return simplified transaction summaries
@@ -116,7 +132,7 @@ async def get_all_transactions(
amount=txn["transactionValue"], amount=txn["transactionValue"],
currency=txn["transactionCurrency"], currency=txn["transactionCurrency"],
status=txn["transactionStatus"], status=txn["transactionStatus"],
account_id=txn["accountId"] account_id=txn["accountId"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
@@ -133,7 +149,7 @@ async def get_all_transactions(
transaction_value=txn["transactionValue"], transaction_value=txn["transactionValue"],
transaction_currency=txn["transactionCurrency"], transaction_currency=txn["transactionCurrency"],
transaction_status=txn["transactionStatus"], transaction_status=txn["transactionStatus"],
raw_transaction=txn["rawTransaction"] raw_transaction=txn["rawTransaction"],
) )
for txn in paginated_transactions for txn in paginated_transactions
] ]
@@ -141,18 +157,20 @@ async def get_all_transactions(
return APIResponse( return APIResponse(
success=True, success=True,
data=data, data=data,
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})" message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get transactions: {e}") logger.error(f"Failed to get transactions: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get transactions: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get transactions: {str(e)}"
)
@router.get("/transactions/stats", response_model=APIResponse) @router.get("/transactions/stats", response_model=APIResponse)
async def get_transaction_stats( async def get_transaction_stats(
days: int = Query(default=30, description="Number of days to include in stats"), days: int = Query(default=30, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID") account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> APIResponse: ) -> APIResponse:
"""Get transaction statistics for the last N days""" """Get transaction statistics for the last N days"""
try: try:
@@ -178,7 +196,9 @@ async def get_transaction_stats(
for acc_id in all_accounts: for acc_id in all_accounts:
try: try:
account_details = await gocardless_service.get_account_details(acc_id) account_details = await gocardless_service.get_account_details(acc_id)
transactions_data = await gocardless_service.get_account_transactions(acc_id) transactions_data = await gocardless_service.get_account_transactions(
acc_id
)
processed_transactions = database_service.process_transactions( processed_transactions = database_service.process_transactions(
acc_id, account_details, transactions_data acc_id, account_details, transactions_data
@@ -191,7 +211,8 @@ async def get_transaction_stats(
# Filter transactions by date range # Filter transactions by date range
recent_transactions = [ recent_transactions = [
txn for txn in all_transactions txn
for txn in all_transactions
if start_date <= txn["transactionDate"] <= end_date if start_date <= txn["transactionDate"] <= end_date
] ]
@@ -210,8 +231,16 @@ async def get_transaction_stats(
net_change = total_income - total_expenses net_change = total_income - total_expenses
# Count by status # Count by status
booked_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]) booked_count = len(
pending_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "pending"]) [txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]
)
pending_count = len(
[
txn
for txn in recent_transactions
if txn["transactionStatus"] == "pending"
]
)
stats = { stats = {
"period_days": days, "period_days": days,
@@ -222,17 +251,23 @@ async def get_transaction_stats(
"total_expenses": round(total_expenses, 2), "total_expenses": round(total_expenses, 2),
"net_change": round(net_change, 2), "net_change": round(net_change, 2),
"average_transaction": round( "average_transaction": round(
sum(txn["transactionValue"] for txn in recent_transactions) / total_transactions, 2 sum(txn["transactionValue"] for txn in recent_transactions)
) if total_transactions > 0 else 0, / total_transactions,
"accounts_included": len(all_accounts) 2,
)
if total_transactions > 0
else 0,
"accounts_included": len(all_accounts),
} }
return APIResponse( return APIResponse(
success=True, success=True,
data=stats, data=stats,
message=f"Transaction statistics for last {days} days" message=f"Transaction statistics for last {days} days",
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get transaction stats: {e}") logger.error(f"Failed to get transaction stats: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get transaction stats: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to get transaction stats: {str(e)}"
)

View File

@@ -4,12 +4,16 @@ from loguru import logger
from leggend.config import config from leggend.config import config
from leggend.services.sync_service import SyncService from leggend.services.sync_service import SyncService
from leggend.services.notification_service import NotificationService
class BackgroundScheduler: class BackgroundScheduler:
def __init__(self): def __init__(self):
self.scheduler = AsyncIOScheduler() self.scheduler = AsyncIOScheduler()
self.sync_service = SyncService() self.sync_service = SyncService()
self.notification_service = NotificationService()
self.max_retries = 3
self.retry_delay = 300 # 5 minutes
def start(self): def start(self):
"""Start the scheduler and configure sync jobs based on configuration""" """Start the scheduler and configure sync jobs based on configuration"""
@@ -20,31 +24,10 @@ class BackgroundScheduler:
self.scheduler.start() self.scheduler.start()
return return
# Use custom cron expression if provided, otherwise use hour/minute # Parse schedule configuration
if schedule_config.get("cron"): trigger = self._parse_cron_config(schedule_config)
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM) if not trigger:
try: return
cron_parts = schedule_config["cron"].split()
if len(cron_parts) == 5:
minute, hour, day, month, day_of_week = cron_parts
trigger = CronTrigger(
minute=minute,
hour=hour,
day=day if day != "*" else None,
month=month if month != "*" else None,
day_of_week=day_of_week if day_of_week != "*" else None,
)
else:
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
return
except Exception as e:
logger.error(f"Error parsing cron expression: {e}")
return
else:
# Use hour/minute configuration (default: 3:00 AM daily)
hour = schedule_config.get("hour", 3)
minute = schedule_config.get("minute", 0)
trigger = CronTrigger(hour=hour, minute=minute)
self.scheduler.add_job( self.scheduler.add_job(
self._run_sync, self._run_sync,
@@ -76,28 +59,9 @@ class BackgroundScheduler:
return return
# Configure new schedule # Configure new schedule
if schedule_config.get("cron"): trigger = self._parse_cron_config(schedule_config)
try: if not trigger:
cron_parts = schedule_config["cron"].split() return
if len(cron_parts) == 5:
minute, hour, day, month, day_of_week = cron_parts
trigger = CronTrigger(
minute=minute,
hour=hour,
day=day if day != "*" else None,
month=month if month != "*" else None,
day_of_week=day_of_week if day_of_week != "*" else None,
)
else:
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
return
except Exception as e:
logger.error(f"Error parsing cron expression: {e}")
return
else:
hour = schedule_config.get("hour", 3)
minute = schedule_config.get("minute", 0)
trigger = CronTrigger(hour=hour, minute=minute)
self.scheduler.add_job( self.scheduler.add_job(
self._run_sync, self._run_sync,
@@ -108,13 +72,90 @@ class BackgroundScheduler:
) )
logger.info(f"Rescheduled sync job with: {trigger}") logger.info(f"Rescheduled sync job with: {trigger}")
async def _run_sync(self): def _parse_cron_config(self, schedule_config: dict) -> CronTrigger:
"""Parse cron configuration and return CronTrigger"""
if schedule_config.get("cron"):
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
try:
cron_parts = schedule_config["cron"].split()
if len(cron_parts) == 5:
minute, hour, day, month, day_of_week = cron_parts
return CronTrigger(
minute=minute,
hour=hour,
day=day if day != "*" else None,
month=month if month != "*" else None,
day_of_week=day_of_week if day_of_week != "*" else None,
)
else:
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
return None
except Exception as e:
logger.error(f"Error parsing cron expression: {e}")
return None
else:
# Use hour/minute configuration (default: 3:00 AM daily)
hour = schedule_config.get("hour", 3)
minute = schedule_config.get("minute", 0)
return CronTrigger(hour=hour, minute=minute)
async def _run_sync(self, retry_count: int = 0):
"""Run sync with enhanced error handling and retry logic"""
try: try:
logger.info("Starting scheduled sync job") logger.info("Starting scheduled sync job")
await self.sync_service.sync_all_accounts() await self.sync_service.sync_all_accounts()
logger.info("Scheduled sync job completed successfully") logger.info("Scheduled sync job completed successfully")
except Exception as e: except Exception as e:
logger.error(f"Scheduled sync job failed: {e}") logger.error(
f"Scheduled sync job failed (attempt {retry_count + 1}/{self.max_retries}): {e}"
)
# Send notification about the failure
try:
await self.notification_service.send_expiry_notification(
{
"type": "sync_failure",
"error": str(e),
"retry_count": retry_count + 1,
"max_retries": self.max_retries,
}
)
except Exception as notification_error:
logger.error(
f"Failed to send failure notification: {notification_error}"
)
# Implement retry logic for transient failures
if retry_count < self.max_retries - 1:
import datetime
logger.info(f"Retrying sync job in {self.retry_delay} seconds...")
# Schedule a retry
retry_time = datetime.datetime.now() + datetime.timedelta(
seconds=self.retry_delay
)
self.scheduler.add_job(
self._run_sync,
"date",
args=[retry_count + 1],
id=f"sync_retry_{retry_count + 1}",
run_date=retry_time,
)
else:
logger.error("Maximum retries exceeded for sync job")
# Send final failure notification
try:
await self.notification_service.send_expiry_notification(
{
"type": "sync_final_failure",
"error": str(e),
"retry_count": retry_count + 1,
}
)
except Exception as notification_error:
logger.error(
f"Failed to send final failure notification: {notification_error}"
)
def get_next_sync_time(self): def get_next_sync_time(self):
"""Get the next scheduled sync time""" """Get the next scheduled sync time"""

View File

@@ -24,7 +24,7 @@ class Config:
if config_path is None: if config_path is None:
config_path = os.environ.get( config_path = os.environ.get(
"LEGGEN_CONFIG_FILE", "LEGGEN_CONFIG_FILE",
str(Path.home() / ".config" / "leggen" / "config.toml") str(Path.home() / ".config" / "leggen" / "config.toml"),
) )
self._config_path = config_path self._config_path = config_path
@@ -42,7 +42,9 @@ class Config:
return self._config return self._config
def save_config(self, config_data: Dict[str, Any] = None, config_path: str = None) -> None: def save_config(
self, config_data: Dict[str, Any] = None, config_path: str = None
) -> None:
"""Save configuration to TOML file""" """Save configuration to TOML file"""
if config_data is None: if config_data is None:
config_data = self._config config_data = self._config
@@ -50,7 +52,7 @@ class Config:
if config_path is None: if config_path is None:
config_path = self._config_path or os.environ.get( config_path = self._config_path or os.environ.get(
"LEGGEN_CONFIG_FILE", "LEGGEN_CONFIG_FILE",
str(Path.home() / ".config" / "leggen" / "config.toml") str(Path.home() / ".config" / "leggen" / "config.toml"),
) )
# Ensure directory exists # Ensure directory exists
@@ -117,7 +119,7 @@ class Config:
"enabled": True, "enabled": True,
"hour": 3, "hour": 3,
"minute": 0, "minute": 0,
"cron": None # Optional custom cron expression "cron": None, # Optional custom cron expression
} }
} }
return self.config.get("scheduler", default_schedule) return self.config.get("scheduler", default_schedule)

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from importlib import metadata
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@@ -36,17 +37,26 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI: def create_app() -> FastAPI:
# Get version dynamically from package metadata
try:
version = metadata.version("leggen")
except metadata.PackageNotFoundError:
version = "unknown"
app = FastAPI( app = FastAPI(
title="Leggend API", title="Leggend API",
description="Open Banking API for Leggen", description="Open Banking API for Leggen",
version="0.6.11", version=version,
lifespan=lifespan, lifespan=lifespan,
) )
# Add CORS middleware # Add CORS middleware
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["http://localhost:3000", "http://localhost:5173"], # SvelteKit dev servers allow_origins=[
"http://localhost:3000",
"http://localhost:5173",
], # SvelteKit dev servers
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
@@ -60,7 +70,12 @@ def create_app() -> FastAPI:
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "Leggend API is running", "version": "0.6.11"} # Get version dynamically
try:
version = metadata.version("leggen")
except metadata.PackageNotFoundError:
version = "unknown"
return {"message": "Leggend API is running", "version": version}
@app.get("/health") @app.get("/health")
async def health(): async def health():
@@ -71,22 +86,16 @@ def create_app() -> FastAPI:
def main(): def main():
import argparse import argparse
parser = argparse.ArgumentParser(description="Start the Leggend API service") parser = argparse.ArgumentParser(description="Start the Leggend API service")
parser.add_argument( parser.add_argument(
"--reload", "--reload", action="store_true", help="Enable auto-reload for development"
action="store_true",
help="Enable auto-reload for development"
) )
parser.add_argument( parser.add_argument(
"--host", "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
default="0.0.0.0",
help="Host to bind to (default: 0.0.0.0)"
) )
parser.add_argument( parser.add_argument(
"--port", "--port", type=int, default=8000, help="Port to bind to (default: 8000)"
type=int,
default=8000,
help="Port to bind to (default: 8000)"
) )
args = parser.parse_args() args = parser.parse_args()

View File

@@ -11,7 +11,9 @@ class DatabaseService:
self.db_config = config.database_config self.db_config = config.database_config
self.sqlite_enabled = self.db_config.get("sqlite", True) self.sqlite_enabled = self.db_config.get("sqlite", True)
async def persist_balance(self, account_id: str, balance_data: Dict[str, Any]) -> None: async def persist_balance(
self, account_id: str, balance_data: Dict[str, Any]
) -> None:
"""Persist account balance data""" """Persist account balance data"""
if not self.sqlite_enabled: if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping balance persistence") logger.warning("SQLite database disabled, skipping balance persistence")
@@ -19,7 +21,9 @@ class DatabaseService:
await self._persist_balance_sqlite(account_id, balance_data) await self._persist_balance_sqlite(account_id, balance_data)
async def persist_transactions(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: async def persist_transactions(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions and return new transactions""" """Persist transactions and return new transactions"""
if not self.sqlite_enabled: if not self.sqlite_enabled:
logger.warning("SQLite database disabled, skipping transaction persistence") logger.warning("SQLite database disabled, skipping transaction persistence")
@@ -27,32 +31,48 @@ class DatabaseService:
return await self._persist_transactions_sqlite(account_id, transactions) return await self._persist_transactions_sqlite(account_id, transactions)
def process_transactions(self, account_id: str, account_info: Dict[str, Any], transaction_data: Dict[str, Any]) -> List[Dict[str, Any]]: def process_transactions(
self,
account_id: str,
account_info: Dict[str, Any],
transaction_data: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Process raw transaction data into standardized format""" """Process raw transaction data into standardized format"""
transactions = [] transactions = []
# Process booked transactions # Process booked transactions
for transaction in transaction_data.get("transactions", {}).get("booked", []): for transaction in transaction_data.get("transactions", {}).get("booked", []):
processed = self._process_single_transaction(account_id, account_info, transaction, "booked") processed = self._process_single_transaction(
account_id, account_info, transaction, "booked"
)
transactions.append(processed) transactions.append(processed)
# Process pending transactions # Process pending transactions
for transaction in transaction_data.get("transactions", {}).get("pending", []): for transaction in transaction_data.get("transactions", {}).get("pending", []):
processed = self._process_single_transaction(account_id, account_info, transaction, "pending") processed = self._process_single_transaction(
account_id, account_info, transaction, "pending"
)
transactions.append(processed) transactions.append(processed)
return transactions return transactions
def _process_single_transaction(self, account_id: str, account_info: Dict[str, Any], transaction: Dict[str, Any], status: str) -> Dict[str, Any]: def _process_single_transaction(
self,
account_id: str,
account_info: Dict[str, Any],
transaction: Dict[str, Any],
status: str,
) -> Dict[str, Any]:
"""Process a single transaction into standardized format""" """Process a single transaction into standardized format"""
# Extract dates # Extract dates
booked_date = transaction.get("bookingDateTime") or transaction.get("bookingDate") booked_date = transaction.get("bookingDateTime") or transaction.get(
"bookingDate"
)
value_date = transaction.get("valueDateTime") or transaction.get("valueDate") value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
if booked_date and value_date: if booked_date and value_date:
min_date = min( min_date = min(
datetime.fromisoformat(booked_date), datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
datetime.fromisoformat(value_date)
) )
else: else:
min_date = datetime.fromisoformat(booked_date or value_date) min_date = datetime.fromisoformat(booked_date or value_date)
@@ -65,7 +85,7 @@ class DatabaseService:
# Extract description # Extract description
description = transaction.get( description = transaction.get(
"remittanceInformationUnstructured", "remittanceInformationUnstructured",
",".join(transaction.get("remittanceInformationUnstructuredArray", [])) ",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
) )
return { return {
@@ -81,13 +101,19 @@ class DatabaseService:
"rawTransaction": transaction, "rawTransaction": transaction,
} }
async def _persist_balance_sqlite(self, account_id: str, balance_data: Dict[str, Any]) -> None: async def _persist_balance_sqlite(
self, account_id: str, balance_data: Dict[str, Any]
) -> None:
"""Persist balance to SQLite - placeholder implementation""" """Persist balance to SQLite - placeholder implementation"""
# Would import and use leggen.database.sqlite # Would import and use leggen.database.sqlite
logger.info(f"Persisting balance to SQLite for account {account_id}") logger.info(f"Persisting balance to SQLite for account {account_id}")
async def _persist_transactions_sqlite(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: async def _persist_transactions_sqlite(
self, account_id: str, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Persist transactions to SQLite - placeholder implementation""" """Persist transactions to SQLite - placeholder implementation"""
# Would import and use leggen.database.sqlite # Would import and use leggen.database.sqlite
logger.info(f"Persisting {len(transactions)} transactions to SQLite for account {account_id}") logger.info(
f"Persisting {len(transactions)} transactions to SQLite for account {account_id}"
)
return transactions # Return new transactions for notifications return transactions # Return new transactions for notifications

View File

@@ -12,16 +12,15 @@ from leggend.config import config
class GoCardlessService: class GoCardlessService:
def __init__(self): def __init__(self):
self.config = config.gocardless_config self.config = config.gocardless_config
self.base_url = self.config.get("url", "https://bankaccountdata.gocardless.com/api/v2") self.base_url = self.config.get(
"url", "https://bankaccountdata.gocardless.com/api/v2"
)
self._token = None self._token = None
async def _get_auth_headers(self) -> Dict[str, str]: async def _get_auth_headers(self) -> Dict[str, str]:
"""Get authentication headers for GoCardless API""" """Get authentication headers for GoCardless API"""
token = await self._get_token() token = await self._get_token()
return { return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
async def _get_token(self) -> str: async def _get_token(self) -> str:
"""Get access token for GoCardless API""" """Get access token for GoCardless API"""
@@ -42,7 +41,7 @@ class GoCardlessService:
try: try:
response = await client.post( response = await client.post(
f"{self.base_url}/token/refresh/", f"{self.base_url}/token/refresh/",
json={"refresh": auth["refresh"]} json={"refresh": auth["refresh"]},
) )
response.raise_for_status() response.raise_for_status()
auth.update(response.json()) auth.update(response.json())
@@ -95,22 +94,21 @@ class GoCardlessService:
response = await client.get( response = await client.get(
f"{self.base_url}/institutions/", f"{self.base_url}/institutions/",
headers=headers, headers=headers,
params={"country": country} params={"country": country},
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
async def create_requisition(self, institution_id: str, redirect_url: str) -> Dict[str, Any]: async def create_requisition(
self, institution_id: str, redirect_url: str
) -> Dict[str, Any]:
"""Create a bank connection requisition""" """Create a bank connection requisition"""
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/requisitions/", f"{self.base_url}/requisitions/",
headers=headers, headers=headers,
json={ json={"institution_id": institution_id, "redirect": redirect_url},
"institution_id": institution_id,
"redirect": redirect_url
}
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -120,8 +118,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/requisitions/", f"{self.base_url}/requisitions/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -131,8 +128,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/accounts/{account_id}/", f"{self.base_url}/accounts/{account_id}/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -142,8 +138,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/accounts/{account_id}/balances/", f"{self.base_url}/accounts/{account_id}/balances/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -153,8 +148,7 @@ class GoCardlessService:
headers = await self._get_auth_headers() headers = await self._get_auth_headers()
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/accounts/{account_id}/transactions/", f"{self.base_url}/accounts/{account_id}/transactions/", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

View File

@@ -10,7 +10,9 @@ class NotificationService:
self.notifications_config = config.notifications_config self.notifications_config = config.notifications_config
self.filters_config = config.filters_config self.filters_config = config.filters_config
async def send_transaction_notifications(self, transactions: List[Dict[str, Any]]) -> None: async def send_transaction_notifications(
self, transactions: List[Dict[str, Any]]
) -> None:
"""Send notifications for new transactions that match filters""" """Send notifications for new transactions that match filters"""
if not self.filters_config: if not self.filters_config:
logger.info("No notification filters configured, skipping notifications") logger.info("No notification filters configured, skipping notifications")
@@ -40,7 +42,9 @@ class NotificationService:
await self._send_telegram_test(message) await self._send_telegram_test(message)
return True return True
else: else:
logger.error(f"Notification service '{service}' not enabled or not found") logger.error(
f"Notification service '{service}' not enabled or not found"
)
return False return False
except Exception as e: except Exception as e:
logger.error(f"Failed to send test notification to {service}: {e}") logger.error(f"Failed to send test notification to {service}: {e}")
@@ -54,7 +58,9 @@ class NotificationService:
if self._is_telegram_enabled(): if self._is_telegram_enabled():
await self._send_telegram_expiry(notification_data) await self._send_telegram_expiry(notification_data)
def _filter_transactions(self, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _filter_transactions(
self, transactions: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Filter transactions based on notification criteria""" """Filter transactions based on notification criteria"""
matching = [] matching = []
filters_case_insensitive = self.filters_config.get("case-insensitive", {}) filters_case_insensitive = self.filters_config.get("case-insensitive", {})
@@ -65,12 +71,14 @@ class NotificationService:
# Check case-insensitive filters # Check case-insensitive filters
for filter_name, filter_value in filters_case_insensitive.items(): for filter_name, filter_value in filters_case_insensitive.items():
if filter_value.lower() in description: if filter_value.lower() in description:
matching.append({ matching.append(
"name": transaction["description"], {
"value": transaction["transactionValue"], "name": transaction["description"],
"currency": transaction["transactionCurrency"], "value": transaction["transactionValue"],
"date": transaction["transactionDate"], "currency": transaction["transactionCurrency"],
}) "date": transaction["transactionDate"],
}
)
break break
return matching return matching
@@ -78,26 +86,34 @@ class NotificationService:
def _is_discord_enabled(self) -> bool: def _is_discord_enabled(self) -> bool:
"""Check if Discord notifications are enabled""" """Check if Discord notifications are enabled"""
discord_config = self.notifications_config.get("discord", {}) discord_config = self.notifications_config.get("discord", {})
return bool(discord_config.get("webhook") and discord_config.get("enabled", True)) return bool(
discord_config.get("webhook") and discord_config.get("enabled", True)
)
def _is_telegram_enabled(self) -> bool: def _is_telegram_enabled(self) -> bool:
"""Check if Telegram notifications are enabled""" """Check if Telegram notifications are enabled"""
telegram_config = self.notifications_config.get("telegram", {}) telegram_config = self.notifications_config.get("telegram", {})
return bool( return bool(
telegram_config.get("token") and telegram_config.get("token")
telegram_config.get("chat_id") and and telegram_config.get("chat_id")
telegram_config.get("enabled", True) and telegram_config.get("enabled", True)
) )
async def _send_discord_notifications(self, transactions: List[Dict[str, Any]]) -> None: async def _send_discord_notifications(
self, transactions: List[Dict[str, Any]]
) -> None:
"""Send Discord notifications - placeholder implementation""" """Send Discord notifications - placeholder implementation"""
# Would import and use leggen.notifications.discord # Would import and use leggen.notifications.discord
logger.info(f"Sending {len(transactions)} transaction notifications to Discord") logger.info(f"Sending {len(transactions)} transaction notifications to Discord")
async def _send_telegram_notifications(self, transactions: List[Dict[str, Any]]) -> None: async def _send_telegram_notifications(
self, transactions: List[Dict[str, Any]]
) -> None:
"""Send Telegram notifications - placeholder implementation""" """Send Telegram notifications - placeholder implementation"""
# Would import and use leggen.notifications.telegram # Would import and use leggen.notifications.telegram
logger.info(f"Sending {len(transactions)} transaction notifications to Telegram") logger.info(
f"Sending {len(transactions)} transaction notifications to Telegram"
)
async def _send_discord_test(self, message: str) -> None: async def _send_discord_test(self, message: str) -> None:
"""Send Discord test notification""" """Send Discord test notification"""

View File

@@ -53,7 +53,9 @@ class SyncService:
for account_id in all_accounts: for account_id in all_accounts:
try: try:
# Get account details # Get account details
account_details = await self.gocardless.get_account_details(account_id) account_details = await self.gocardless.get_account_details(
account_id
)
# Get and save balances # Get and save balances
balances = await self.gocardless.get_account_balances(account_id) balances = await self.gocardless.get_account_balances(account_id)
@@ -62,7 +64,9 @@ class SyncService:
balances_updated += len(balances.get("balances", [])) balances_updated += len(balances.get("balances", []))
# Get and save transactions # Get and save transactions
transactions = await self.gocardless.get_account_transactions(account_id) transactions = await self.gocardless.get_account_transactions(
account_id
)
if transactions: if transactions:
processed_transactions = self.database.process_transactions( processed_transactions = self.database.process_transactions(
account_id, account_details, transactions account_id, account_details, transactions
@@ -74,7 +78,9 @@ class SyncService:
# Send notifications for new transactions # Send notifications for new transactions
if new_transactions: if new_transactions:
await self.notifications.send_transaction_notifications(new_transactions) await self.notifications.send_transaction_notifications(
new_transactions
)
accounts_processed += 1 accounts_processed += 1
self._sync_status.accounts_synced = accounts_processed self._sync_status.accounts_synced = accounts_processed
@@ -100,10 +106,12 @@ class SyncService:
duration_seconds=duration, duration_seconds=duration,
errors=errors, errors=errors,
started_at=start_time, started_at=start_time,
completed_at=end_time completed_at=end_time,
) )
logger.info(f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions") logger.info(
f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions"
)
return result return result
except Exception as e: except Exception as e:
@@ -114,7 +122,9 @@ class SyncService:
finally: finally:
self._sync_status.is_running = False self._sync_status.is_running = False
async def sync_specific_accounts(self, account_ids: List[str], force: bool = False) -> SyncResult: async def sync_specific_accounts(
self, account_ids: List[str], force: bool = False
) -> SyncResult:
"""Sync specific accounts""" """Sync specific accounts"""
if self._sync_status.is_running and not force: if self._sync_status.is_running and not force:
raise Exception("Sync is already running") raise Exception("Sync is already running")
@@ -139,7 +149,7 @@ class SyncService:
duration_seconds=(end_time - start_time).total_seconds(), duration_seconds=(end_time - start_time).total_seconds(),
errors=[], errors=[],
started_at=start_time, started_at=start_time,
completed_at=end_time completed_at=end_time,
) )
finally: finally:
self._sync_status.is_running = False self._sync_status.is_running = False

View File

View File

@@ -1,4 +1,5 @@
"""Pytest configuration and shared fixtures.""" """Pytest configuration and shared fixtures."""
import pytest import pytest
import tempfile import tempfile
import json import json
@@ -26,27 +27,20 @@ def mock_config(temp_config_dir):
"gocardless": { "gocardless": {
"key": "test-key", "key": "test-key",
"secret": "test-secret", "secret": "test-secret",
"url": "https://bankaccountdata.gocardless.com/api/v2" "url": "https://bankaccountdata.gocardless.com/api/v2",
}, },
"database": { "database": {"sqlite": True},
"sqlite": True "scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}},
},
"scheduler": {
"sync": {
"enabled": True,
"hour": 3,
"minute": 0
}
}
} }
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(config_data, f) tomli_w.dump(config_data, f)
# Mock the config path # Mock the config path
with patch.object(Config, 'load_config') as mock_load: with patch.object(Config, "load_config") as mock_load:
mock_load.return_value = config_data mock_load.return_value = config_data
config = Config() config = Config()
config._config = config_data config._config = config_data
@@ -57,10 +51,7 @@ def mock_config(temp_config_dir):
@pytest.fixture @pytest.fixture
def mock_auth_token(temp_config_dir): def mock_auth_token(temp_config_dir):
"""Mock GoCardless authentication token.""" """Mock GoCardless authentication token."""
auth_data = { auth_data = {"access": "mock-access-token", "refresh": "mock-refresh-token"}
"access": "mock-access-token",
"refresh": "mock-refresh-token"
}
auth_file = temp_config_dir / "auth.json" auth_file = temp_config_dir / "auth.json"
with open(auth_file, "w") as f: with open(auth_file, "w") as f:
@@ -90,15 +81,15 @@ def sample_bank_data():
"name": "Revolut", "name": "Revolut",
"bic": "REVOLT21", "bic": "REVOLT21",
"transaction_total_days": 90, "transaction_total_days": 90,
"countries": ["GB", "LT"] "countries": ["GB", "LT"],
}, },
{ {
"id": "BANCOBPI_BBPIPTPL", "id": "BANCOBPI_BBPIPTPL",
"name": "Banco BPI", "name": "Banco BPI",
"bic": "BBPIPTPL", "bic": "BBPIPTPL",
"transaction_total_days": 90, "transaction_total_days": 90,
"countries": ["PT"] "countries": ["PT"],
} },
] ]
@@ -111,7 +102,7 @@ def sample_account_data():
"status": "READY", "status": "READY",
"iban": "LT313250081177977789", "iban": "LT313250081177977789",
"created": "2024-02-13T23:56:00Z", "created": "2024-02-13T23:56:00Z",
"last_accessed": "2025-09-01T09:30:00Z" "last_accessed": "2025-09-01T09:30:00Z",
} }
@@ -125,13 +116,10 @@ def sample_transaction_data():
"internalTransactionId": "txn-123", "internalTransactionId": "txn-123",
"bookingDate": "2025-09-01", "bookingDate": "2025-09-01",
"valueDate": "2025-09-01", "valueDate": "2025-09-01",
"transactionAmount": { "transactionAmount": {"amount": "-10.50", "currency": "EUR"},
"amount": "-10.50", "remittanceInformationUnstructured": "Coffee Shop Payment",
"currency": "EUR"
},
"remittanceInformationUnstructured": "Coffee Shop Payment"
} }
], ],
"pending": [] "pending": [],
} }
} }

View File

@@ -1,4 +1,5 @@
"""Tests for accounts API endpoints.""" """Tests for accounts API endpoints."""
import pytest import pytest
import respx import respx
import httpx import httpx
@@ -10,15 +11,12 @@ class TestAccountsAPI:
"""Test account-related API endpoints.""" """Test account-related API endpoints."""
@respx.mock @respx.mock
def test_get_all_accounts_success(self, api_client, mock_config, mock_auth_token, sample_account_data): def test_get_all_accounts_success(
self, api_client, mock_config, mock_auth_token, sample_account_data
):
"""Test successful retrieval of all accounts.""" """Test successful retrieval of all accounts."""
requisitions_data = { requisitions_data = {
"results": [ "results": [{"id": "req-123", "accounts": ["test-account-123"]}]
{
"id": "req-123",
"accounts": ["test-account-123"]
}
]
} }
balances_data = { balances_data = {
@@ -26,28 +24,30 @@ class TestAccountsAPI:
{ {
"balanceAmount": {"amount": "100.50", "currency": "EUR"}, "balanceAmount": {"amount": "100.50", "currency": "EUR"},
"balanceType": "interimAvailable", "balanceType": "interimAvailable",
"lastChangeDateTime": "2025-09-01T09:30:00Z" "lastChangeDateTime": "2025-09-01T09:30:00Z",
} }
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
return_value=httpx.Response(200, json=requisitions_data) return_value=httpx.Response(200, json=requisitions_data)
) )
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( respx.get(
return_value=httpx.Response(200, json=balances_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
) ).mock(return_value=httpx.Response(200, json=balances_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts") response = api_client.get("/api/v1/accounts")
assert response.status_code == 200 assert response.status_code == 200
@@ -61,31 +61,35 @@ class TestAccountsAPI:
assert account["balances"][0]["amount"] == 100.50 assert account["balances"][0]["amount"] == 100.50
@respx.mock @respx.mock
def test_get_account_details_success(self, api_client, mock_config, mock_auth_token, sample_account_data): def test_get_account_details_success(
self, api_client, mock_config, mock_auth_token, sample_account_data
):
"""Test successful retrieval of specific account details.""" """Test successful retrieval of specific account details."""
balances_data = { balances_data = {
"balances": [ "balances": [
{ {
"balanceAmount": {"amount": "250.75", "currency": "EUR"}, "balanceAmount": {"amount": "250.75", "currency": "EUR"},
"balanceType": "interimAvailable" "balanceType": "interimAvailable",
} }
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( respx.get(
return_value=httpx.Response(200, json=balances_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
) ).mock(return_value=httpx.Response(200, json=balances_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123") response = api_client.get("/api/v1/accounts/test-account-123")
assert response.status_code == 200 assert response.status_code == 200
@@ -97,33 +101,37 @@ class TestAccountsAPI:
assert len(account["balances"]) == 1 assert len(account["balances"]) == 1
@respx.mock @respx.mock
def test_get_account_balances_success(self, api_client, mock_config, mock_auth_token): def test_get_account_balances_success(
self, api_client, mock_config, mock_auth_token
):
"""Test successful retrieval of account balances.""" """Test successful retrieval of account balances."""
balances_data = { balances_data = {
"balances": [ "balances": [
{ {
"balanceAmount": {"amount": "1000.00", "currency": "EUR"}, "balanceAmount": {"amount": "1000.00", "currency": "EUR"},
"balanceType": "interimAvailable", "balanceType": "interimAvailable",
"lastChangeDateTime": "2025-09-01T10:00:00Z" "lastChangeDateTime": "2025-09-01T10:00:00Z",
}, },
{ {
"balanceAmount": {"amount": "950.00", "currency": "EUR"}, "balanceAmount": {"amount": "950.00", "currency": "EUR"},
"balanceType": "expected" "balanceType": "expected",
} },
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API # Mock GoCardless API
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( respx.get(
return_value=httpx.Response(200, json=balances_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
) ).mock(return_value=httpx.Response(200, json=balances_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/balances") response = api_client.get("/api/v1/accounts/test-account-123/balances")
assert response.status_code == 200 assert response.status_code == 200
@@ -135,23 +143,34 @@ class TestAccountsAPI:
assert data["data"][0]["balance_type"] == "interimAvailable" assert data["data"][0]["balance_type"] == "interimAvailable"
@respx.mock @respx.mock
def test_get_account_transactions_success(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data): def test_get_account_transactions_success(
self,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
):
"""Test successful retrieval of account transactions.""" """Test successful retrieval of account transactions."""
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock( respx.get(
return_value=httpx.Response(200, json=sample_transaction_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
) ).mock(return_value=httpx.Response(200, json=sample_transaction_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=true") response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -165,23 +184,34 @@ class TestAccountsAPI:
assert transaction["description"] == "Coffee Shop Payment" assert transaction["description"] == "Coffee Shop Payment"
@respx.mock @respx.mock
def test_get_account_transactions_full_details(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data): def test_get_account_transactions_full_details(
self,
api_client,
mock_config,
mock_auth_token,
sample_account_data,
sample_transaction_data,
):
"""Test retrieval of full transaction details.""" """Test retrieval of full transaction details."""
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless API calls # Mock GoCardless API calls
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( respx.get(
return_value=httpx.Response(200, json=sample_account_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
) ).mock(return_value=httpx.Response(200, json=sample_account_data))
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock( respx.get(
return_value=httpx.Response(200, json=sample_transaction_data) "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
) ).mock(return_value=httpx.Response(200, json=sample_transaction_data))
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=false") response = api_client.get(
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -200,14 +230,18 @@ class TestAccountsAPI:
with respx.mock: with respx.mock:
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/").mock( respx.get(
"https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/"
).mock(
return_value=httpx.Response(404, json={"detail": "Account not found"}) return_value=httpx.Response(404, json={"detail": "Account not found"})
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/accounts/nonexistent") response = api_client.get("/api/v1/accounts/nonexistent")
assert response.status_code == 404 assert response.status_code == 404

View File

@@ -1,4 +1,5 @@
"""Tests for banks API endpoints.""" """Tests for banks API endpoints."""
import pytest import pytest
import respx import respx
import httpx import httpx
@@ -12,11 +13,15 @@ class TestBanksAPI:
"""Test bank-related API endpoints.""" """Test bank-related API endpoints."""
@respx.mock @respx.mock
def test_get_institutions_success(self, api_client, mock_config, mock_auth_token, sample_bank_data): def test_get_institutions_success(
self, api_client, mock_config, mock_auth_token, sample_bank_data
):
"""Test successful retrieval of bank institutions.""" """Test successful retrieval of bank institutions."""
# Mock GoCardless token creation/refresh # Mock GoCardless token creation/refresh
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless institutions API # Mock GoCardless institutions API
@@ -24,7 +29,7 @@ class TestBanksAPI:
return_value=httpx.Response(200, json=sample_bank_data) return_value=httpx.Response(200, json=sample_bank_data)
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/institutions?country=PT") response = api_client.get("/api/v1/banks/institutions?country=PT")
assert response.status_code == 200 assert response.status_code == 200
@@ -39,7 +44,9 @@ class TestBanksAPI:
"""Test institutions endpoint with invalid country code.""" """Test institutions endpoint with invalid country code."""
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock empty institutions response for invalid country # Mock empty institutions response for invalid country
@@ -47,7 +54,7 @@ class TestBanksAPI:
return_value=httpx.Response(200, json=[]) return_value=httpx.Response(200, json=[])
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/institutions?country=XX") response = api_client.get("/api/v1/banks/institutions?country=XX")
# Should still work but return empty or filtered results # Should still work but return empty or filtered results
@@ -61,12 +68,14 @@ class TestBanksAPI:
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"status": "CR", "status": "CR",
"created": "2025-09-02T00:00:00Z", "created": "2025-09-02T00:00:00Z",
"link": "https://example.com/auth" "link": "https://example.com/auth",
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless requisitions API # Mock GoCardless requisitions API
@@ -76,10 +85,10 @@ class TestBanksAPI:
request_data = { request_data = {
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"redirect_url": "http://localhost:8000/" "redirect_url": "http://localhost:8000/",
} }
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.post("/api/v1/banks/connect", json=request_data) response = api_client.post("/api/v1/banks/connect", json=request_data)
assert response.status_code == 200 assert response.status_code == 200
@@ -98,14 +107,16 @@ class TestBanksAPI:
"institution_id": "REVOLUT_REVOLT21", "institution_id": "REVOLUT_REVOLT21",
"status": "LN", "status": "LN",
"created": "2025-09-02T00:00:00Z", "created": "2025-09-02T00:00:00Z",
"accounts": ["acc-123"] "accounts": ["acc-123"],
} }
] ]
} }
# Mock GoCardless token creation # Mock GoCardless token creation
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) return_value=httpx.Response(
200, json={"access": "test-token", "refresh": "test-refresh"}
)
) )
# Mock GoCardless requisitions API # Mock GoCardless requisitions API
@@ -113,7 +124,7 @@ class TestBanksAPI:
return_value=httpx.Response(200, json=requisitions_data) return_value=httpx.Response(200, json=requisitions_data)
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/status") response = api_client.get("/api/v1/banks/status")
assert response.status_code == 200 assert response.status_code == 200
@@ -146,7 +157,7 @@ class TestBanksAPI:
return_value=httpx.Response(401, json={"detail": "Invalid credentials"}) return_value=httpx.Response(401, json={"detail": "Invalid credentials"})
) )
with patch('leggend.config.config', mock_config): with patch("leggend.config.config", mock_config):
response = api_client.get("/api/v1/banks/institutions") response = api_client.get("/api/v1/banks/institutions")
assert response.status_code == 500 assert response.status_code == 500

View File

@@ -1,4 +1,5 @@
"""Tests for CLI API client.""" """Tests for CLI API client."""
import pytest import pytest
import requests_mock import requests_mock
from unittest.mock import patch from unittest.mock import patch
@@ -37,7 +38,7 @@ class TestLeggendAPIClient:
api_response = { api_response = {
"success": True, "success": True,
"data": sample_bank_data, "data": sample_bank_data,
"message": "Found 2 institutions for PT" "message": "Found 2 institutions for PT",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
@@ -54,7 +55,7 @@ class TestLeggendAPIClient:
api_response = { api_response = {
"success": True, "success": True,
"data": [sample_account_data], "data": [sample_account_data],
"message": "Retrieved 1 accounts" "message": "Retrieved 1 accounts",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
@@ -71,7 +72,7 @@ class TestLeggendAPIClient:
api_response = { api_response = {
"success": True, "success": True,
"data": {"sync_started": True, "force": False}, "data": {"sync_started": True, "force": False},
"message": "Started sync for all accounts" "message": "Started sync for all accounts",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
@@ -92,8 +93,11 @@ class TestLeggendAPIClient:
client = LeggendAPIClient("http://localhost:8000") client = LeggendAPIClient("http://localhost:8000")
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
m.get("http://localhost:8000/api/v1/accounts", status_code=500, m.get(
json={"detail": "Internal server error"}) "http://localhost:8000/api/v1/accounts",
status_code=500,
json={"detail": "Internal server error"},
)
with pytest.raises(Exception): with pytest.raises(Exception):
client.get_accounts() client.get_accounts()
@@ -107,7 +111,7 @@ class TestLeggendAPIClient:
def test_environment_variable_url(self): def test_environment_variable_url(self):
"""Test using environment variable for API URL.""" """Test using environment variable for API URL."""
with patch.dict('os.environ', {'LEGGEND_API_URL': 'http://env-host:7000'}): with patch.dict("os.environ", {"LEGGEND_API_URL": "http://env-host:7000"}):
client = LeggendAPIClient() client = LeggendAPIClient()
assert client.base_url == "http://env-host:7000" assert client.base_url == "http://env-host:7000"
@@ -118,7 +122,7 @@ class TestLeggendAPIClient:
api_response = { api_response = {
"success": True, "success": True,
"data": {"sync_started": True, "force": True}, "data": {"sync_started": True, "force": True},
"message": "Started sync for 2 specific accounts" "message": "Started sync for 2 specific accounts",
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:
@@ -138,8 +142,8 @@ class TestLeggendAPIClient:
"enabled": True, "enabled": True,
"hour": 3, "hour": 3,
"minute": 0, "minute": 0,
"next_scheduled_sync": "2025-09-03T03:00:00Z" "next_scheduled_sync": "2025-09-03T03:00:00Z",
} },
} }
with requests_mock.Mocker() as m: with requests_mock.Mocker() as m:

View File

@@ -1,4 +1,5 @@
"""Tests for configuration management.""" """Tests for configuration management."""
import pytest import pytest
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@@ -23,16 +24,15 @@ class TestConfig:
"gocardless": { "gocardless": {
"key": "test-key", "key": "test-key",
"secret": "test-secret", "secret": "test-secret",
"url": "https://test.example.com" "url": "https://test.example.com",
}, },
"database": { "database": {"sqlite": True},
"sqlite": True
}
} }
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(config_data, f) tomli_w.dump(config_data, f)
config = Config() config = Config()
@@ -56,12 +56,7 @@ class TestConfig:
def test_save_config_success(self, temp_config_dir): def test_save_config_success(self, temp_config_dir):
"""Test successful configuration saving.""" """Test successful configuration saving."""
config_data = { config_data = {"gocardless": {"key": "new-key", "secret": "new-secret"}}
"gocardless": {
"key": "new-key",
"secret": "new-secret"
}
}
config_file = temp_config_dir / "new_config.toml" config_file = temp_config_dir / "new_config.toml"
config = Config() config = Config()
@@ -73,6 +68,7 @@ class TestConfig:
assert config_file.exists() assert config_file.exists()
import tomllib import tomllib
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
saved_data = tomllib.load(f) saved_data = tomllib.load(f)
@@ -82,12 +78,13 @@ class TestConfig:
"""Test updating configuration values.""" """Test updating configuration values."""
initial_config = { initial_config = {
"gocardless": {"key": "old-key"}, "gocardless": {"key": "old-key"},
"database": {"sqlite": True} "database": {"sqlite": True},
} }
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(initial_config, f) tomli_w.dump(initial_config, f)
config = Config() config = Config()
@@ -100,19 +97,19 @@ class TestConfig:
# Verify it was saved to file # Verify it was saved to file
import tomllib import tomllib
with open(config_file, "rb") as f: with open(config_file, "rb") as f:
saved_data = tomllib.load(f) saved_data = tomllib.load(f)
assert saved_data["gocardless"]["key"] == "new-key" assert saved_data["gocardless"]["key"] == "new-key"
def test_update_section_success(self, temp_config_dir): def test_update_section_success(self, temp_config_dir):
"""Test updating entire configuration section.""" """Test updating entire configuration section."""
initial_config = { initial_config = {"database": {"sqlite": True}}
"database": {"sqlite": True}
}
config_file = temp_config_dir / "config.toml" config_file = temp_config_dir / "config.toml"
with open(config_file, "wb") as f: with open(config_file, "wb") as f:
import tomli_w import tomli_w
tomli_w.dump(initial_config, f) tomli_w.dump(initial_config, f)
config = Config() config = Config()
@@ -144,7 +141,7 @@ class TestConfig:
"enabled": False, "enabled": False,
"hour": 6, "hour": 6,
"minute": 30, "minute": 30,
"cron": "0 6 * * 1-5" "cron": "0 6 * * 1-5",
} }
} }
} }
@@ -161,11 +158,13 @@ class TestConfig:
def test_environment_variable_config_path(self): def test_environment_variable_config_path(self):
"""Test using environment variable for config path.""" """Test using environment variable for config path."""
with patch.dict('os.environ', {'LEGGEN_CONFIG_FILE': '/custom/path/config.toml'}): with patch.dict(
"os.environ", {"LEGGEN_CONFIG_FILE": "/custom/path/config.toml"}
):
config = Config() config = Config()
config._config = None config._config = None
with patch('builtins.open', side_effect=FileNotFoundError): with patch("builtins.open", side_effect=FileNotFoundError):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
config.load_config() config.load_config()
@@ -174,7 +173,7 @@ class TestConfig:
custom_config = { custom_config = {
"notifications": { "notifications": {
"discord": {"webhook": "https://discord.webhook", "enabled": True}, "discord": {"webhook": "https://discord.webhook", "enabled": True},
"telegram": {"token": "bot-token", "chat_id": 123} "telegram": {"token": "bot-token", "chat_id": 123},
} }
} }
@@ -190,7 +189,7 @@ class TestConfig:
custom_config = { custom_config = {
"filters": { "filters": {
"case-insensitive": {"salary": "SALARY", "bills": "BILL"}, "case-insensitive": {"salary": "SALARY", "bills": "BILL"},
"amount_threshold": 100.0 "amount_threshold": 100.0,
} }
} }

View File

@@ -1,4 +1,5 @@
"""Tests for background scheduler.""" """Tests for background scheduler."""
import pytest import pytest
import asyncio import asyncio
from unittest.mock import Mock, patch, AsyncMock, MagicMock from unittest.mock import Mock, patch, AsyncMock, MagicMock
@@ -16,22 +17,18 @@ class TestBackgroundScheduler:
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Mock configuration for scheduler tests.""" """Mock configuration for scheduler tests."""
return { return {"sync": {"enabled": True, "hour": 3, "minute": 0, "cron": None}}
"sync": {
"enabled": True,
"hour": 3,
"minute": 0,
"cron": None
}
}
@pytest.fixture @pytest.fixture
def scheduler(self): def scheduler(self):
"""Create scheduler instance for testing.""" """Create scheduler instance for testing."""
with patch('leggend.background.scheduler.SyncService'), \ with (
patch('leggend.background.scheduler.config') as mock_config: patch("leggend.background.scheduler.SyncService"),
patch("leggend.background.scheduler.config") as mock_config,
mock_config.scheduler_config = {"sync": {"enabled": True, "hour": 3, "minute": 0}} ):
mock_config.scheduler_config = {
"sync": {"enabled": True, "hour": 3, "minute": 0}
}
# Create scheduler and replace its AsyncIO scheduler with a mock # Create scheduler and replace its AsyncIO scheduler with a mock
scheduler = BackgroundScheduler() scheduler = BackgroundScheduler()
@@ -43,7 +40,7 @@ class TestBackgroundScheduler:
def test_scheduler_start_default_config(self, scheduler, mock_config): def test_scheduler_start_default_config(self, scheduler, mock_config):
"""Test starting scheduler with default configuration.""" """Test starting scheduler with default configuration."""
with patch('leggend.config.config') as mock_config_obj: with patch("leggend.config.config") as mock_config_obj:
mock_config_obj.scheduler_config = mock_config mock_config_obj.scheduler_config = mock_config
# Mock the job that gets added # Mock the job that gets added
@@ -60,13 +57,12 @@ class TestBackgroundScheduler:
def test_scheduler_start_disabled(self, scheduler): def test_scheduler_start_disabled(self, scheduler):
"""Test scheduler behavior when sync is disabled.""" """Test scheduler behavior when sync is disabled."""
disabled_config = { disabled_config = {"sync": {"enabled": False}}
"sync": {"enabled": False}
}
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
patch('leggend.background.scheduler.config') as mock_config_obj:
with (
patch.object(scheduler, "scheduler") as mock_scheduler,
patch("leggend.background.scheduler.config") as mock_config_obj,
):
mock_config_obj.scheduler_config = disabled_config mock_config_obj.scheduler_config = disabled_config
mock_scheduler.running = False mock_scheduler.running = False
@@ -82,11 +78,11 @@ class TestBackgroundScheduler:
cron_config = { cron_config = {
"sync": { "sync": {
"enabled": True, "enabled": True,
"cron": "0 6 * * 1-5" # 6 AM on weekdays "cron": "0 6 * * 1-5", # 6 AM on weekdays
} }
} }
with patch('leggend.config.config') as mock_config_obj: with patch("leggend.config.config") as mock_config_obj:
mock_config_obj.scheduler_config = cron_config mock_config_obj.scheduler_config = cron_config
scheduler.start() scheduler.start()
@@ -96,20 +92,16 @@ class TestBackgroundScheduler:
scheduler.scheduler.add_job.assert_called_once() scheduler.scheduler.add_job.assert_called_once()
# Verify job was added with correct ID # Verify job was added with correct ID
call_args = scheduler.scheduler.add_job.call_args call_args = scheduler.scheduler.add_job.call_args
assert call_args.kwargs['id'] == 'daily_sync' assert call_args.kwargs["id"] == "daily_sync"
def test_scheduler_start_invalid_cron(self, scheduler): def test_scheduler_start_invalid_cron(self, scheduler):
"""Test handling of invalid cron expressions.""" """Test handling of invalid cron expressions."""
invalid_cron_config = { invalid_cron_config = {"sync": {"enabled": True, "cron": "invalid cron"}}
"sync": {
"enabled": True,
"cron": "invalid cron"
}
}
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
patch('leggend.background.scheduler.config') as mock_config_obj:
with (
patch.object(scheduler, "scheduler") as mock_scheduler,
patch("leggend.background.scheduler.config") as mock_config_obj,
):
mock_config_obj.scheduler_config = invalid_cron_config mock_config_obj.scheduler_config = invalid_cron_config
mock_scheduler.running = False mock_scheduler.running = False
@@ -133,11 +125,7 @@ class TestBackgroundScheduler:
scheduler.scheduler.running = True scheduler.scheduler.running = True
# Reschedule with new config # Reschedule with new config
new_config = { new_config = {"enabled": True, "hour": 6, "minute": 30}
"enabled": True,
"hour": 6,
"minute": 30
}
scheduler.reschedule_sync(new_config) scheduler.reschedule_sync(new_config)
@@ -202,10 +190,10 @@ class TestBackgroundScheduler:
def test_scheduler_job_max_instances(self, scheduler, mock_config): def test_scheduler_job_max_instances(self, scheduler, mock_config):
"""Test that sync jobs have max_instances=1.""" """Test that sync jobs have max_instances=1."""
with patch('leggend.config.config') as mock_config_obj: with patch("leggend.config.config") as mock_config_obj:
mock_config_obj.scheduler_config = mock_config mock_config_obj.scheduler_config = mock_config
scheduler.start() scheduler.start()
# Verify add_job was called with max_instances=1 # Verify add_job was called with max_instances=1
call_args = scheduler.scheduler.add_job.call_args call_args = scheduler.scheduler.add_job.call_args
assert call_args.kwargs['max_instances'] == 1 assert call_args.kwargs["max_instances"] == 1