mirror of
https://github.com/elisiariocouto/leggen.git
synced 2025-12-13 11:22:21 +00:00
chore: Implement code review suggestions and format code.
This commit is contained in:
committed by
Elisiário Couto
parent
47164e8546
commit
de3da84dff
@@ -1,32 +0,0 @@
|
||||
FROM python:3.13-alpine AS builder
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
uv sync --frozen --no-install-project --no-editable
|
||||
|
||||
COPY . /app
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --frozen --no-editable --no-group dev
|
||||
|
||||
FROM python:3.13-alpine
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/elisiariocouto/leggen"
|
||||
LABEL org.opencontainers.image.authors="Elisiário Couto <elisiario@couto.io>"
|
||||
LABEL org.opencontainers.image.licenses="MIT"
|
||||
LABEL org.opencontainers.image.title="leggend"
|
||||
LABEL org.opencontainers.image.description="Leggen API service"
|
||||
LABEL org.opencontainers.image.url="https://github.com/elisiariocouto/leggen"
|
||||
|
||||
WORKDIR /app
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
COPY --from=builder --chown=app:app /app/.venv /app/.venv
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
ENTRYPOINT ["/app/.venv/bin/leggend"]
|
||||
15
compose.yml
15
compose.yml
@@ -3,7 +3,6 @@ services:
|
||||
leggend:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.leggend
|
||||
restart: "unless-stopped"
|
||||
ports:
|
||||
- "127.0.0.1:8000:8000"
|
||||
@@ -18,20 +17,6 @@ services:
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
# CLI for one-off operations (uses leggend API)
|
||||
leggen:
|
||||
image: elisiariocouto/leggen:latest
|
||||
command: sync --wait
|
||||
restart: "no"
|
||||
volumes:
|
||||
- "./leggen:/root/.config/leggen"
|
||||
- "./db:/app"
|
||||
environment:
|
||||
- LEGGEND_API_URL=http://leggend:8000
|
||||
depends_on:
|
||||
leggend:
|
||||
condition: service_healthy
|
||||
|
||||
nocodb:
|
||||
image: nocodb/nocodb:latest
|
||||
restart: "unless-stopped"
|
||||
|
||||
@@ -10,12 +10,13 @@ class LeggendAPIClient:
|
||||
"""Client for communicating with the leggend FastAPI service"""
|
||||
|
||||
def __init__(self, base_url: Optional[str] = None):
|
||||
self.base_url = base_url or os.environ.get("LEGGEND_API_URL", "http://localhost:8000")
|
||||
self.base_url = base_url or os.environ.get(
|
||||
"LEGGEND_API_URL", "http://localhost:8000"
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
})
|
||||
self.session.headers.update(
|
||||
{"Content-Type": "application/json", "Accept": "application/json"}
|
||||
)
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Make HTTP request to the API"""
|
||||
@@ -53,15 +54,19 @@ class LeggendAPIClient:
|
||||
# Bank endpoints
|
||||
def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]:
|
||||
"""Get bank institutions for a country"""
|
||||
response = self._make_request("GET", "/api/v1/banks/institutions", params={"country": country})
|
||||
response = self._make_request(
|
||||
"GET", "/api/v1/banks/institutions", params={"country": country}
|
||||
)
|
||||
return response.get("data", [])
|
||||
|
||||
def connect_to_bank(self, institution_id: str, redirect_url: str = "http://localhost:8000/") -> Dict[str, Any]:
|
||||
def connect_to_bank(
|
||||
self, institution_id: str, redirect_url: str = "http://localhost:8000/"
|
||||
) -> Dict[str, Any]:
|
||||
"""Connect to a bank"""
|
||||
response = self._make_request(
|
||||
"POST",
|
||||
"/api/v1/banks/connect",
|
||||
json={"institution_id": institution_id, "redirect_url": redirect_url}
|
||||
json={"institution_id": institution_id, "redirect_url": redirect_url},
|
||||
)
|
||||
return response.get("data", {})
|
||||
|
||||
@@ -91,17 +96,21 @@ class LeggendAPIClient:
|
||||
response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances")
|
||||
return response.get("data", [])
|
||||
|
||||
def get_account_transactions(self, account_id: str, limit: int = 100, summary_only: bool = False) -> List[Dict[str, Any]]:
|
||||
def get_account_transactions(
|
||||
self, account_id: str, limit: int = 100, summary_only: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get account transactions"""
|
||||
response = self._make_request(
|
||||
"GET",
|
||||
f"/api/v1/accounts/{account_id}/transactions",
|
||||
params={"limit": limit, "summary_only": summary_only}
|
||||
params={"limit": limit, "summary_only": summary_only},
|
||||
)
|
||||
return response.get("data", [])
|
||||
|
||||
# Transaction endpoints
|
||||
def get_all_transactions(self, limit: int = 100, summary_only: bool = True, **filters) -> List[Dict[str, Any]]:
|
||||
def get_all_transactions(
|
||||
self, limit: int = 100, summary_only: bool = True, **filters
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all transactions with optional filters"""
|
||||
params = {"limit": limit, "summary_only": summary_only}
|
||||
params.update(filters)
|
||||
@@ -109,13 +118,17 @@ class LeggendAPIClient:
|
||||
response = self._make_request("GET", "/api/v1/transactions", params=params)
|
||||
return response.get("data", [])
|
||||
|
||||
def get_transaction_stats(self, days: int = 30, account_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
def get_transaction_stats(
|
||||
self, days: int = 30, account_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get transaction statistics"""
|
||||
params = {"days": days}
|
||||
if account_id:
|
||||
params["account_id"] = account_id
|
||||
|
||||
response = self._make_request("GET", "/api/v1/transactions/stats", params=params)
|
||||
response = self._make_request(
|
||||
"GET", "/api/v1/transactions/stats", params=params
|
||||
)
|
||||
return response.get("data", {})
|
||||
|
||||
# Sync endpoints
|
||||
@@ -124,7 +137,9 @@ class LeggendAPIClient:
|
||||
response = self._make_request("GET", "/api/v1/sync/status")
|
||||
return response.get("data", {})
|
||||
|
||||
def trigger_sync(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]:
|
||||
def trigger_sync(
|
||||
self, account_ids: Optional[List[str]] = None, force: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Trigger a sync"""
|
||||
data = {"force": force}
|
||||
if account_ids:
|
||||
@@ -133,7 +148,9 @@ class LeggendAPIClient:
|
||||
response = self._make_request("POST", "/api/v1/sync", json=data)
|
||||
return response.get("data", {})
|
||||
|
||||
def sync_now(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]:
|
||||
def sync_now(
|
||||
self, account_ids: Optional[List[str]] = None, force: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Run sync synchronously"""
|
||||
data = {"force": force}
|
||||
if account_ids:
|
||||
@@ -147,7 +164,13 @@ class LeggendAPIClient:
|
||||
response = self._make_request("GET", "/api/v1/sync/scheduler")
|
||||
return response.get("data", {})
|
||||
|
||||
def update_scheduler_config(self, enabled: bool = True, hour: int = 3, minute: int = 0, cron: Optional[str] = None) -> Dict[str, Any]:
|
||||
def update_scheduler_config(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
hour: int = 3,
|
||||
minute: int = 0,
|
||||
cron: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Update scheduler configuration"""
|
||||
data = {"enabled": enabled, "hour": hour, "minute": minute}
|
||||
if cron:
|
||||
|
||||
@@ -15,7 +15,9 @@ def balances(ctx: click.Context):
|
||||
|
||||
# Check if leggend service is available
|
||||
if not api_client.health_check():
|
||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
accounts = api_client.get_accounts()
|
||||
@@ -24,11 +26,7 @@ def balances(ctx: click.Context):
|
||||
for account in accounts:
|
||||
for balance in account.get("balances", []):
|
||||
amount = round(float(balance["amount"]), 2)
|
||||
symbol = (
|
||||
"€"
|
||||
if balance["currency"] == "EUR"
|
||||
else f" {balance['currency']}"
|
||||
)
|
||||
symbol = "€" if balance["currency"] == "EUR" else f" {balance['currency']}"
|
||||
amount_str = f"{amount}{symbol}"
|
||||
date = (
|
||||
datefmt(balance.get("last_change_date"))
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from leggen.main import cli
|
||||
|
||||
cmd_folder = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
class BankGroup(click.Group):
|
||||
def list_commands(self, ctx):
|
||||
rv = []
|
||||
for filename in os.listdir(cmd_folder):
|
||||
if filename.endswith(".py") and not filename.startswith("__init__"):
|
||||
if filename == "list_banks.py":
|
||||
rv.append("list")
|
||||
else:
|
||||
rv.append(filename[:-3])
|
||||
rv.sort()
|
||||
return rv
|
||||
|
||||
def get_command(self, ctx, name):
|
||||
try:
|
||||
if name == "list":
|
||||
name = "list_banks"
|
||||
mod = __import__(f"leggen.commands.bank.{name}", None, None, [name])
|
||||
except ImportError:
|
||||
return
|
||||
return getattr(mod, name)
|
||||
|
||||
|
||||
@cli.group(cls=BankGroup)
|
||||
@click.pass_context
|
||||
def bank(ctx):
|
||||
"""Manage banks connections"""
|
||||
return
|
||||
@@ -16,7 +16,9 @@ def add(ctx):
|
||||
|
||||
# Check if leggend service is available
|
||||
if not api_client.health_check():
|
||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -65,11 +67,15 @@ def add(ctx):
|
||||
save_file(f"req_{result['id']}.json", result)
|
||||
|
||||
success("Bank connection request created successfully!")
|
||||
warning(f"Please open the following URL in your browser to complete the authorization:")
|
||||
warning(
|
||||
f"Please open the following URL in your browser to complete the authorization:"
|
||||
)
|
||||
click.echo(f"\n{result['link']}\n")
|
||||
|
||||
info(f"Requisition ID: {result['id']}")
|
||||
info("After completing the authorization, you can check the connection status with 'leggen status'")
|
||||
info(
|
||||
"After completing the authorization, you can check the connection status with 'leggen status'"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: Failed to connect to bank: {str(e)}")
|
||||
|
||||
@@ -15,7 +15,9 @@ def status(ctx: click.Context):
|
||||
|
||||
# Check if leggend service is available
|
||||
if not api_client.health_check():
|
||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
# Get bank connection status
|
||||
|
||||
@@ -6,8 +6,8 @@ from leggen.utils.text import error, info, success
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option('--wait', is_flag=True, help='Wait for sync to complete (synchronous)')
|
||||
@click.option('--force', is_flag=True, help='Force sync even if already running')
|
||||
@click.option("--wait", is_flag=True, help="Wait for sync to complete (synchronous)")
|
||||
@click.option("--force", is_flag=True, help="Force sync even if already running")
|
||||
@click.pass_context
|
||||
def sync(ctx: click.Context, wait: bool, force: bool):
|
||||
"""
|
||||
@@ -31,17 +31,17 @@ def sync(ctx: click.Context, wait: bool, force: bool):
|
||||
info(f"Accounts processed: {result.get('accounts_processed', 0)}")
|
||||
info(f"Transactions added: {result.get('transactions_added', 0)}")
|
||||
info(f"Balances updated: {result.get('balances_updated', 0)}")
|
||||
if result.get('duration_seconds'):
|
||||
if result.get("duration_seconds"):
|
||||
info(f"Duration: {result['duration_seconds']:.2f} seconds")
|
||||
|
||||
if result.get('errors'):
|
||||
if result.get("errors"):
|
||||
error(f"Errors encountered: {len(result['errors'])}")
|
||||
for err in result['errors']:
|
||||
for err in result["errors"]:
|
||||
error(f" - {err}")
|
||||
else:
|
||||
error("Sync failed")
|
||||
if result.get('errors'):
|
||||
for err in result['errors']:
|
||||
if result.get("errors"):
|
||||
for err in result["errors"]:
|
||||
error(f" - {err}")
|
||||
else:
|
||||
# Trigger async sync
|
||||
@@ -50,7 +50,9 @@ def sync(ctx: click.Context, wait: bool, force: bool):
|
||||
|
||||
if result.get("sync_started"):
|
||||
success("Sync started successfully in the background")
|
||||
info("Use 'leggen sync --wait' to run synchronously or check status with API")
|
||||
info(
|
||||
"Use 'leggen sync --wait' to run synchronously or check status with API"
|
||||
)
|
||||
else:
|
||||
error("Failed to start sync")
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@ from leggen.utils.text import datefmt, info, print_table
|
||||
|
||||
@cli.command()
|
||||
@click.option("-a", "--account", type=str, help="Account ID")
|
||||
@click.option("-l", "--limit", type=int, default=50, help="Number of transactions to show")
|
||||
@click.option(
|
||||
"-l", "--limit", type=int, default=50, help="Number of transactions to show"
|
||||
)
|
||||
@click.option("--full", is_flag=True, help="Show full transaction details")
|
||||
@click.pass_context
|
||||
def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
||||
@@ -22,7 +24,9 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
||||
|
||||
# Check if leggend service is available
|
||||
if not api_client.health_check():
|
||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -39,9 +43,7 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
||||
else:
|
||||
# Get all transactions
|
||||
transactions_data = api_client.get_all_transactions(
|
||||
limit=limit,
|
||||
summary_only=not full,
|
||||
account_id=account
|
||||
limit=limit, summary_only=not full, account_id=account
|
||||
)
|
||||
|
||||
# Format transactions for display
|
||||
@@ -49,24 +51,32 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
||||
# Full transaction details
|
||||
formatted_transactions = []
|
||||
for txn in transactions_data:
|
||||
formatted_transactions.append({
|
||||
formatted_transactions.append(
|
||||
{
|
||||
"ID": txn["internal_transaction_id"][:12] + "...",
|
||||
"Date": datefmt(txn["transaction_date"]),
|
||||
"Description": txn["description"][:50] + "..." if len(txn["description"]) > 50 else txn["description"],
|
||||
"Description": txn["description"][:50] + "..."
|
||||
if len(txn["description"]) > 50
|
||||
else txn["description"],
|
||||
"Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}",
|
||||
"Status": txn["transaction_status"].upper(),
|
||||
"Account": txn["account_id"][:8] + "...",
|
||||
})
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Summary view
|
||||
formatted_transactions = []
|
||||
for txn in transactions_data:
|
||||
formatted_transactions.append({
|
||||
formatted_transactions.append(
|
||||
{
|
||||
"Date": datefmt(txn["date"]),
|
||||
"Description": txn["description"][:60] + "..." if len(txn["description"]) > 60 else txn["description"],
|
||||
"Description": txn["description"][:60] + "..."
|
||||
if len(txn["description"]) > 60
|
||||
else txn["description"],
|
||||
"Amount": f"{txn['amount']:.2f} {txn['currency']}",
|
||||
"Status": txn["status"].upper(),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
if formatted_transactions:
|
||||
print_table(formatted_transactions)
|
||||
|
||||
@@ -90,10 +90,10 @@ class Group(click.Group):
|
||||
@click.option(
|
||||
"--api-url",
|
||||
type=str,
|
||||
default=None,
|
||||
default="http://localhost:8000",
|
||||
envvar="LEGGEND_API_URL",
|
||||
show_envvar=True,
|
||||
help="URL of the leggend API service (default: http://localhost:8000)",
|
||||
help="URL of the leggend API service",
|
||||
)
|
||||
@click.group(
|
||||
cls=Group,
|
||||
|
||||
@@ -6,19 +6,19 @@ from pydantic import BaseModel
|
||||
|
||||
class AccountBalance(BaseModel):
|
||||
"""Account balance model"""
|
||||
|
||||
amount: float
|
||||
currency: str
|
||||
balance_type: str
|
||||
last_change_date: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class AccountDetails(BaseModel):
|
||||
"""Account details model"""
|
||||
|
||||
id: str
|
||||
institution_id: str
|
||||
status: str
|
||||
@@ -30,13 +30,12 @@ class AccountDetails(BaseModel):
|
||||
balances: List[AccountBalance] = []
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class Transaction(BaseModel):
|
||||
"""Transaction model"""
|
||||
|
||||
internal_transaction_id: str
|
||||
institution_id: str
|
||||
iban: Optional[str] = None
|
||||
@@ -49,13 +48,12 @@ class Transaction(BaseModel):
|
||||
raw_transaction: Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class TransactionSummary(BaseModel):
|
||||
"""Transaction summary for lists"""
|
||||
|
||||
internal_transaction_id: str
|
||||
date: datetime
|
||||
description: str
|
||||
@@ -65,6 +63,4 @@ class TransactionSummary(BaseModel):
|
||||
account_id: str
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
|
||||
class BankInstitution(BaseModel):
|
||||
"""Bank institution model"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
bic: Optional[str] = None
|
||||
@@ -16,12 +17,14 @@ class BankInstitution(BaseModel):
|
||||
|
||||
class BankConnectionRequest(BaseModel):
|
||||
"""Request to connect to a bank"""
|
||||
|
||||
institution_id: str
|
||||
redirect_url: Optional[str] = "http://localhost:8000/"
|
||||
|
||||
|
||||
class BankRequisition(BaseModel):
|
||||
"""Bank requisition/connection model"""
|
||||
|
||||
id: str
|
||||
institution_id: str
|
||||
status: str
|
||||
@@ -31,13 +34,12 @@ class BankRequisition(BaseModel):
|
||||
accounts: List[str] = []
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class BankConnectionStatus(BaseModel):
|
||||
"""Bank connection status response"""
|
||||
|
||||
bank_id: str
|
||||
bank_name: str
|
||||
status: str
|
||||
@@ -47,6 +49,4 @@ class BankConnectionStatus(BaseModel):
|
||||
accounts_count: int
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
|
||||
class APIResponse(BaseModel):
|
||||
"""Base API response model"""
|
||||
|
||||
success: bool = True
|
||||
message: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
@@ -13,6 +14,7 @@ class APIResponse(BaseModel):
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response model"""
|
||||
|
||||
success: bool = False
|
||||
message: str
|
||||
error_code: Optional[str] = None
|
||||
@@ -21,6 +23,7 @@ class ErrorResponse(BaseModel):
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""Paginated response model"""
|
||||
|
||||
success: bool = True
|
||||
data: list
|
||||
pagination: Dict[str, Any]
|
||||
|
||||
@@ -5,12 +5,14 @@ from pydantic import BaseModel
|
||||
|
||||
class DiscordConfig(BaseModel):
|
||||
"""Discord notification configuration"""
|
||||
|
||||
webhook: str
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class TelegramConfig(BaseModel):
|
||||
"""Telegram notification configuration"""
|
||||
|
||||
token: str
|
||||
chat_id: int
|
||||
enabled: bool = True
|
||||
@@ -18,6 +20,7 @@ class TelegramConfig(BaseModel):
|
||||
|
||||
class NotificationFilters(BaseModel):
|
||||
"""Notification filters configuration"""
|
||||
|
||||
case_insensitive: Dict[str, str] = {}
|
||||
case_sensitive: Optional[Dict[str, str]] = None
|
||||
amount_threshold: Optional[float] = None
|
||||
@@ -26,6 +29,7 @@ class NotificationFilters(BaseModel):
|
||||
|
||||
class NotificationSettings(BaseModel):
|
||||
"""Complete notification settings"""
|
||||
|
||||
discord: Optional[DiscordConfig] = None
|
||||
telegram: Optional[TelegramConfig] = None
|
||||
filters: NotificationFilters = NotificationFilters()
|
||||
@@ -33,12 +37,14 @@ class NotificationSettings(BaseModel):
|
||||
|
||||
class NotificationTest(BaseModel):
|
||||
"""Test notification request"""
|
||||
|
||||
service: str # "discord" or "telegram"
|
||||
message: str = "Test notification from Leggen"
|
||||
|
||||
|
||||
class NotificationHistory(BaseModel):
|
||||
"""Notification history entry"""
|
||||
|
||||
id: str
|
||||
service: str
|
||||
message: str
|
||||
|
||||
@@ -6,12 +6,14 @@ from pydantic import BaseModel
|
||||
|
||||
class SyncRequest(BaseModel):
|
||||
"""Request to trigger a sync"""
|
||||
|
||||
account_ids: Optional[list[str]] = None # If None, sync all accounts
|
||||
force: bool = False # Force sync even if recently synced
|
||||
|
||||
|
||||
class SyncStatus(BaseModel):
|
||||
"""Sync operation status"""
|
||||
|
||||
is_running: bool
|
||||
last_sync: Optional[datetime] = None
|
||||
next_sync: Optional[datetime] = None
|
||||
@@ -21,13 +23,12 @@ class SyncStatus(BaseModel):
|
||||
errors: list[str] = []
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class SyncResult(BaseModel):
|
||||
"""Result of a sync operation"""
|
||||
|
||||
success: bool
|
||||
accounts_processed: int
|
||||
transactions_added: int
|
||||
@@ -39,13 +40,12 @@ class SyncResult(BaseModel):
|
||||
completed_at: datetime
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class SchedulerConfig(BaseModel):
|
||||
"""Scheduler configuration model"""
|
||||
|
||||
enabled: bool = True
|
||||
hour: Optional[int] = 3
|
||||
minute: Optional[int] = 0
|
||||
|
||||
@@ -3,7 +3,12 @@ from fastapi import APIRouter, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from leggend.api.models.common import APIResponse
|
||||
from leggend.api.models.accounts import AccountDetails, AccountBalance, Transaction, TransactionSummary
|
||||
from leggend.api.models.accounts import (
|
||||
AccountDetails,
|
||||
AccountBalance,
|
||||
Transaction,
|
||||
TransactionSummary,
|
||||
)
|
||||
from leggend.services.gocardless_service import GoCardlessService
|
||||
from leggend.services.database_service import DatabaseService
|
||||
|
||||
@@ -25,21 +30,28 @@ async def get_all_accounts() -> APIResponse:
|
||||
accounts = []
|
||||
for account_id in all_accounts:
|
||||
try:
|
||||
account_details = await gocardless_service.get_account_details(account_id)
|
||||
balances_data = await gocardless_service.get_account_balances(account_id)
|
||||
account_details = await gocardless_service.get_account_details(
|
||||
account_id
|
||||
)
|
||||
balances_data = await gocardless_service.get_account_balances(
|
||||
account_id
|
||||
)
|
||||
|
||||
# Process balances
|
||||
balances = []
|
||||
for balance in balances_data.get("balances", []):
|
||||
balance_amount = balance["balanceAmount"]
|
||||
balances.append(AccountBalance(
|
||||
balances.append(
|
||||
AccountBalance(
|
||||
amount=float(balance_amount["amount"]),
|
||||
currency=balance_amount["currency"],
|
||||
balance_type=balance["balanceType"],
|
||||
last_change_date=balance.get("lastChangeDateTime")
|
||||
))
|
||||
last_change_date=balance.get("lastChangeDateTime"),
|
||||
)
|
||||
)
|
||||
|
||||
accounts.append(AccountDetails(
|
||||
accounts.append(
|
||||
AccountDetails(
|
||||
id=account_details["id"],
|
||||
institution_id=account_details["institution_id"],
|
||||
status=account_details["status"],
|
||||
@@ -48,17 +60,16 @@ async def get_all_accounts() -> APIResponse:
|
||||
currency=account_details.get("currency"),
|
||||
created=account_details["created"],
|
||||
last_accessed=account_details.get("last_accessed"),
|
||||
balances=balances
|
||||
))
|
||||
balances=balances,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get details for account {account_id}: {e}")
|
||||
continue
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=accounts,
|
||||
message=f"Retrieved {len(accounts)} accounts"
|
||||
success=True, data=accounts, message=f"Retrieved {len(accounts)} accounts"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -77,12 +88,14 @@ async def get_account_details(account_id: str) -> APIResponse:
|
||||
balances = []
|
||||
for balance in balances_data.get("balances", []):
|
||||
balance_amount = balance["balanceAmount"]
|
||||
balances.append(AccountBalance(
|
||||
balances.append(
|
||||
AccountBalance(
|
||||
amount=float(balance_amount["amount"]),
|
||||
currency=balance_amount["currency"],
|
||||
balance_type=balance["balanceType"],
|
||||
last_change_date=balance.get("lastChangeDateTime")
|
||||
))
|
||||
last_change_date=balance.get("lastChangeDateTime"),
|
||||
)
|
||||
)
|
||||
|
||||
account = AccountDetails(
|
||||
id=account_details["id"],
|
||||
@@ -93,13 +106,13 @@ async def get_account_details(account_id: str) -> APIResponse:
|
||||
currency=account_details.get("currency"),
|
||||
created=account_details["created"],
|
||||
last_accessed=account_details.get("last_accessed"),
|
||||
balances=balances
|
||||
balances=balances,
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=account,
|
||||
message=f"Account details retrieved for {account_id}"
|
||||
message=f"Account details retrieved for {account_id}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -116,17 +129,19 @@ async def get_account_balances(account_id: str) -> APIResponse:
|
||||
balances = []
|
||||
for balance in balances_data.get("balances", []):
|
||||
balance_amount = balance["balanceAmount"]
|
||||
balances.append(AccountBalance(
|
||||
balances.append(
|
||||
AccountBalance(
|
||||
amount=float(balance_amount["amount"]),
|
||||
currency=balance_amount["currency"],
|
||||
balance_type=balance["balanceType"],
|
||||
last_change_date=balance.get("lastChangeDateTime")
|
||||
))
|
||||
last_change_date=balance.get("lastChangeDateTime"),
|
||||
)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=balances,
|
||||
message=f"Retrieved {len(balances)} balances for account {account_id}"
|
||||
message=f"Retrieved {len(balances)} balances for account {account_id}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -139,12 +154,16 @@ async def get_account_transactions(
|
||||
account_id: str,
|
||||
limit: Optional[int] = Query(default=100, le=500),
|
||||
offset: Optional[int] = Query(default=0, ge=0),
|
||||
summary_only: bool = Query(default=False, description="Return transaction summaries only")
|
||||
summary_only: bool = Query(
|
||||
default=False, description="Return transaction summaries only"
|
||||
),
|
||||
) -> APIResponse:
|
||||
"""Get transactions for a specific account"""
|
||||
try:
|
||||
account_details = await gocardless_service.get_account_details(account_id)
|
||||
transactions_data = await gocardless_service.get_account_transactions(account_id)
|
||||
transactions_data = await gocardless_service.get_account_transactions(
|
||||
account_id
|
||||
)
|
||||
|
||||
# Process transactions
|
||||
processed_transactions = database_service.process_transactions(
|
||||
@@ -165,7 +184,7 @@ async def get_account_transactions(
|
||||
amount=txn["transactionValue"],
|
||||
currency=txn["transactionCurrency"],
|
||||
status=txn["transactionStatus"],
|
||||
account_id=txn["accountId"]
|
||||
account_id=txn["accountId"],
|
||||
)
|
||||
for txn in paginated_transactions
|
||||
]
|
||||
@@ -183,7 +202,7 @@ async def get_account_transactions(
|
||||
transaction_value=txn["transactionValue"],
|
||||
transaction_currency=txn["transactionCurrency"],
|
||||
transaction_status=txn["transactionStatus"],
|
||||
raw_transaction=txn["rawTransaction"]
|
||||
raw_transaction=txn["rawTransaction"],
|
||||
)
|
||||
for txn in paginated_transactions
|
||||
]
|
||||
@@ -192,9 +211,11 @@ async def get_account_transactions(
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=data,
|
||||
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})"
|
||||
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transactions for account {account_id}: {e}")
|
||||
raise HTTPException(status_code=404, detail=f"Failed to get transactions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Failed to get transactions: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from leggend.api.models.banks import (
|
||||
BankInstitution,
|
||||
BankConnectionRequest,
|
||||
BankRequisition,
|
||||
BankConnectionStatus
|
||||
BankConnectionStatus,
|
||||
)
|
||||
from leggend.services.gocardless_service import GoCardlessService
|
||||
from leggend.utils.gocardless import REQUISITION_STATUS
|
||||
@@ -18,7 +18,7 @@ gocardless_service = GoCardlessService()
|
||||
|
||||
@router.get("/banks/institutions", response_model=APIResponse)
|
||||
async def get_bank_institutions(
|
||||
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)")
|
||||
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)"),
|
||||
) -> APIResponse:
|
||||
"""Get available bank institutions for a country"""
|
||||
try:
|
||||
@@ -31,7 +31,7 @@ async def get_bank_institutions(
|
||||
bic=inst.get("bic"),
|
||||
transaction_total_days=inst["transaction_total_days"],
|
||||
countries=inst["countries"],
|
||||
logo=inst.get("logo")
|
||||
logo=inst.get("logo"),
|
||||
)
|
||||
for inst in institutions_data
|
||||
]
|
||||
@@ -39,12 +39,14 @@ async def get_bank_institutions(
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=institutions,
|
||||
message=f"Found {len(institutions)} institutions for {country}"
|
||||
message=f"Found {len(institutions)} institutions for {country}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get institutions for {country}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get institutions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get institutions: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/banks/connect", response_model=APIResponse)
|
||||
@@ -52,8 +54,7 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
|
||||
"""Create a connection to a bank (requisition)"""
|
||||
try:
|
||||
requisition_data = await gocardless_service.create_requisition(
|
||||
request.institution_id,
|
||||
request.redirect_url
|
||||
request.institution_id, request.redirect_url
|
||||
)
|
||||
|
||||
requisition = BankRequisition(
|
||||
@@ -62,18 +63,20 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
|
||||
status=requisition_data["status"],
|
||||
created=requisition_data["created"],
|
||||
link=requisition_data["link"],
|
||||
accounts=requisition_data.get("accounts", [])
|
||||
accounts=requisition_data.get("accounts", []),
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=requisition,
|
||||
message=f"Bank connection created. Please visit the link to authorize."
|
||||
message=f"Bank connection created. Please visit the link to authorize.",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to bank {request.institution_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to connect to bank: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to connect to bank: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/banks/status", response_model=APIResponse)
|
||||
@@ -87,25 +90,31 @@ async def get_bank_connections_status() -> APIResponse:
|
||||
status = req["status"]
|
||||
status_display = REQUISITION_STATUS.get(status, "UNKNOWN")
|
||||
|
||||
connections.append(BankConnectionStatus(
|
||||
connections.append(
|
||||
BankConnectionStatus(
|
||||
bank_id=req["institution_id"],
|
||||
bank_name=req["institution_id"], # Could be enhanced with actual bank names
|
||||
bank_name=req[
|
||||
"institution_id"
|
||||
], # Could be enhanced with actual bank names
|
||||
status=status,
|
||||
status_display=status_display,
|
||||
created_at=req["created"],
|
||||
requisition_id=req["id"],
|
||||
accounts_count=len(req.get("accounts", []))
|
||||
))
|
||||
accounts_count=len(req.get("accounts", [])),
|
||||
)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=connections,
|
||||
message=f"Found {len(connections)} bank connections"
|
||||
message=f"Found {len(connections)} bank connections",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get bank connection status: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get bank status: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get bank status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/banks/connections/{requisition_id}", response_model=APIResponse)
|
||||
@@ -116,12 +125,14 @@ async def delete_bank_connection(requisition_id: str) -> APIResponse:
|
||||
# For now, return success
|
||||
return APIResponse(
|
||||
success=True,
|
||||
message=f"Bank connection {requisition_id} deleted successfully"
|
||||
message=f"Bank connection {requisition_id} deleted successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete bank connection {requisition_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete connection: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete connection: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/banks/countries", response_model=APIResponse)
|
||||
@@ -164,5 +175,5 @@ async def get_supported_countries() -> APIResponse:
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=countries,
|
||||
message="Supported countries retrieved successfully"
|
||||
message="Supported countries retrieved successfully",
|
||||
)
|
||||
@@ -8,7 +8,7 @@ from leggend.api.models.notifications import (
|
||||
NotificationTest,
|
||||
DiscordConfig,
|
||||
TelegramConfig,
|
||||
NotificationFilters
|
||||
NotificationFilters,
|
||||
)
|
||||
from leggend.services.notification_service import NotificationService
|
||||
from leggend.config import config
|
||||
@@ -31,30 +31,36 @@ async def get_notification_settings() -> APIResponse:
|
||||
settings = NotificationSettings(
|
||||
discord=DiscordConfig(
|
||||
webhook="***" if discord_config.get("webhook") else "",
|
||||
enabled=discord_config.get("enabled", True)
|
||||
) if discord_config.get("webhook") else None,
|
||||
enabled=discord_config.get("enabled", True),
|
||||
)
|
||||
if discord_config.get("webhook")
|
||||
else None,
|
||||
telegram=TelegramConfig(
|
||||
token="***" if telegram_config.get("token") else "",
|
||||
chat_id=telegram_config.get("chat_id", 0),
|
||||
enabled=telegram_config.get("enabled", True)
|
||||
) if telegram_config.get("token") else None,
|
||||
enabled=telegram_config.get("enabled", True),
|
||||
)
|
||||
if telegram_config.get("token")
|
||||
else None,
|
||||
filters=NotificationFilters(
|
||||
case_insensitive=filters_config.get("case-insensitive", {}),
|
||||
case_sensitive=filters_config.get("case-sensitive"),
|
||||
amount_threshold=filters_config.get("amount_threshold"),
|
||||
keywords=filters_config.get("keywords", [])
|
||||
)
|
||||
keywords=filters_config.get("keywords", []),
|
||||
),
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=settings,
|
||||
message="Notification settings retrieved successfully"
|
||||
message="Notification settings retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get notification settings: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get notification settings: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get notification settings: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/notifications/settings", response_model=APIResponse)
|
||||
@@ -67,14 +73,14 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
|
||||
if settings.discord:
|
||||
notifications_config["discord"] = {
|
||||
"webhook": settings.discord.webhook,
|
||||
"enabled": settings.discord.enabled
|
||||
"enabled": settings.discord.enabled,
|
||||
}
|
||||
|
||||
if settings.telegram:
|
||||
notifications_config["telegram"] = {
|
||||
"token": settings.telegram.token,
|
||||
"chat_id": settings.telegram.chat_id,
|
||||
"enabled": settings.telegram.enabled
|
||||
"enabled": settings.telegram.enabled,
|
||||
}
|
||||
|
||||
# Update filters config
|
||||
@@ -97,12 +103,14 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"updated": True},
|
||||
message="Notification settings updated successfully"
|
||||
message="Notification settings updated successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update notification settings: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update notification settings: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update notification settings: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/notifications/test", response_model=APIResponse)
|
||||
@@ -110,25 +118,26 @@ async def test_notification(test_request: NotificationTest) -> APIResponse:
|
||||
"""Send a test notification"""
|
||||
try:
|
||||
success = await notification_service.send_test_notification(
|
||||
test_request.service,
|
||||
test_request.message
|
||||
test_request.service, test_request.message
|
||||
)
|
||||
|
||||
if success:
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"sent": True},
|
||||
message=f"Test notification sent to {test_request.service} successfully"
|
||||
message=f"Test notification sent to {test_request.service} successfully",
|
||||
)
|
||||
else:
|
||||
return APIResponse(
|
||||
success=False,
|
||||
message=f"Failed to send test notification to {test_request.service}"
|
||||
message=f"Failed to send test notification to {test_request.service}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send test notification: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to send test notification: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to send test notification: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/notifications/services", response_model=APIResponse)
|
||||
@@ -141,32 +150,36 @@ async def get_notification_services() -> APIResponse:
|
||||
"discord": {
|
||||
"name": "Discord",
|
||||
"enabled": bool(notifications_config.get("discord", {}).get("webhook")),
|
||||
"configured": bool(notifications_config.get("discord", {}).get("webhook")),
|
||||
"active": notifications_config.get("discord", {}).get("enabled", True)
|
||||
"configured": bool(
|
||||
notifications_config.get("discord", {}).get("webhook")
|
||||
),
|
||||
"active": notifications_config.get("discord", {}).get("enabled", True),
|
||||
},
|
||||
"telegram": {
|
||||
"name": "Telegram",
|
||||
"enabled": bool(
|
||||
notifications_config.get("telegram", {}).get("token") and
|
||||
notifications_config.get("telegram", {}).get("chat_id")
|
||||
notifications_config.get("telegram", {}).get("token")
|
||||
and notifications_config.get("telegram", {}).get("chat_id")
|
||||
),
|
||||
"configured": bool(
|
||||
notifications_config.get("telegram", {}).get("token") and
|
||||
notifications_config.get("telegram", {}).get("chat_id")
|
||||
notifications_config.get("telegram", {}).get("token")
|
||||
and notifications_config.get("telegram", {}).get("chat_id")
|
||||
),
|
||||
"active": notifications_config.get("telegram", {}).get("enabled", True)
|
||||
}
|
||||
"active": notifications_config.get("telegram", {}).get("enabled", True),
|
||||
},
|
||||
}
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=services,
|
||||
message="Notification services status retrieved successfully"
|
||||
message="Notification services status retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get notification services: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get notification services: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get notification services: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/notifications/settings/{service}", response_model=APIResponse)
|
||||
@@ -174,7 +187,9 @@ async def delete_notification_service(service: str) -> APIResponse:
|
||||
"""Delete/disable a notification service"""
|
||||
try:
|
||||
if service not in ["discord", "telegram"]:
|
||||
raise HTTPException(status_code=400, detail="Service must be 'discord' or 'telegram'")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Service must be 'discord' or 'telegram'"
|
||||
)
|
||||
|
||||
notifications_config = config.notifications_config.copy()
|
||||
if service in notifications_config:
|
||||
@@ -184,9 +199,11 @@ async def delete_notification_service(service: str) -> APIResponse:
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"deleted": service},
|
||||
message=f"{service.capitalize()} notification service deleted successfully"
|
||||
message=f"{service.capitalize()} notification service deleted successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete notification service {service}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete notification service: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete notification service: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -24,20 +24,19 @@ async def get_sync_status() -> APIResponse:
|
||||
status.next_sync = next_sync_time
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=status,
|
||||
message="Sync status retrieved successfully"
|
||||
success=True, data=status, message="Sync status retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get sync status: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get sync status: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get sync status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync", response_model=APIResponse)
|
||||
async def trigger_sync(
|
||||
background_tasks: BackgroundTasks,
|
||||
sync_request: Optional[SyncRequest] = None
|
||||
background_tasks: BackgroundTasks, sync_request: Optional[SyncRequest] = None
|
||||
) -> APIResponse:
|
||||
"""Trigger a manual sync operation"""
|
||||
try:
|
||||
@@ -46,7 +45,7 @@ async def trigger_sync(
|
||||
if status.is_running and not (sync_request and sync_request.force):
|
||||
return APIResponse(
|
||||
success=False,
|
||||
message="Sync is already running. Use 'force: true' to override."
|
||||
message="Sync is already running. Use 'force: true' to override.",
|
||||
)
|
||||
|
||||
# Determine what to sync
|
||||
@@ -55,21 +54,26 @@ async def trigger_sync(
|
||||
background_tasks.add_task(
|
||||
sync_service.sync_specific_accounts,
|
||||
sync_request.account_ids,
|
||||
sync_request.force if sync_request else False
|
||||
sync_request.force if sync_request else False,
|
||||
)
|
||||
message = (
|
||||
f"Started sync for {len(sync_request.account_ids)} specific accounts"
|
||||
)
|
||||
message = f"Started sync for {len(sync_request.account_ids)} specific accounts"
|
||||
else:
|
||||
# Sync all accounts in background
|
||||
background_tasks.add_task(
|
||||
sync_service.sync_all_accounts,
|
||||
sync_request.force if sync_request else False
|
||||
sync_request.force if sync_request else False,
|
||||
)
|
||||
message = "Started sync for all accounts"
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"sync_started": True, "force": sync_request.force if sync_request else False},
|
||||
message=message
|
||||
data={
|
||||
"sync_started": True,
|
||||
"force": sync_request.force if sync_request else False,
|
||||
},
|
||||
message=message,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -83,8 +87,7 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
|
||||
try:
|
||||
if sync_request and sync_request.account_ids:
|
||||
result = await sync_service.sync_specific_accounts(
|
||||
sync_request.account_ids,
|
||||
sync_request.force
|
||||
sync_request.account_ids, sync_request.force
|
||||
)
|
||||
else:
|
||||
result = await sync_service.sync_all_accounts(
|
||||
@@ -94,7 +97,9 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
|
||||
return APIResponse(
|
||||
success=result.success,
|
||||
data=result,
|
||||
message="Sync completed" if result.success else f"Sync failed with {len(result.errors)} errors"
|
||||
message="Sync completed"
|
||||
if result.success
|
||||
else f"Sync failed with {len(result.errors)} errors",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -111,19 +116,25 @@ async def get_scheduler_config() -> APIResponse:
|
||||
|
||||
response_data = {
|
||||
**scheduler_config,
|
||||
"next_scheduled_sync": next_sync_time.isoformat() if next_sync_time else None,
|
||||
"is_running": scheduler.scheduler.running if hasattr(scheduler, 'scheduler') else False
|
||||
"next_scheduled_sync": next_sync_time.isoformat()
|
||||
if next_sync_time
|
||||
else None,
|
||||
"is_running": scheduler.scheduler.running
|
||||
if hasattr(scheduler, "scheduler")
|
||||
else False,
|
||||
}
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=response_data,
|
||||
message="Scheduler configuration retrieved successfully"
|
||||
message="Scheduler configuration retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get scheduler config: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get scheduler config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get scheduler config: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/sync/scheduler", response_model=APIResponse)
|
||||
@@ -135,9 +146,13 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo
|
||||
try:
|
||||
cron_parts = scheduler_config.cron.split()
|
||||
if len(cron_parts) != 5:
|
||||
raise ValueError("Cron expression must have 5 parts: minute hour day month day_of_week")
|
||||
raise ValueError(
|
||||
"Cron expression must have 5 parts: minute hour day month day_of_week"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid cron expression: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid cron expression: {str(e)}"
|
||||
)
|
||||
|
||||
# Update configuration
|
||||
schedule_data = scheduler_config.dict(exclude_none=True)
|
||||
@@ -149,12 +164,14 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=schedule_data,
|
||||
message="Scheduler configuration updated successfully"
|
||||
message="Scheduler configuration updated successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update scheduler config: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update scheduler config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update scheduler config: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync/scheduler/start", response_model=APIResponse)
|
||||
@@ -163,19 +180,15 @@ async def start_scheduler() -> APIResponse:
|
||||
try:
|
||||
if not scheduler.scheduler.running:
|
||||
scheduler.start()
|
||||
return APIResponse(
|
||||
success=True,
|
||||
message="Scheduler started successfully"
|
||||
)
|
||||
return APIResponse(success=True, message="Scheduler started successfully")
|
||||
else:
|
||||
return APIResponse(
|
||||
success=True,
|
||||
message="Scheduler is already running"
|
||||
)
|
||||
return APIResponse(success=True, message="Scheduler is already running")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start scheduler: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start scheduler: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to start scheduler: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync/scheduler/stop", response_model=APIResponse)
|
||||
@@ -184,16 +197,12 @@ async def stop_scheduler() -> APIResponse:
|
||||
try:
|
||||
if scheduler.scheduler.running:
|
||||
scheduler.shutdown()
|
||||
return APIResponse(
|
||||
success=True,
|
||||
message="Scheduler stopped successfully"
|
||||
)
|
||||
return APIResponse(success=True, message="Scheduler stopped successfully")
|
||||
else:
|
||||
return APIResponse(
|
||||
success=True,
|
||||
message="Scheduler is already stopped"
|
||||
)
|
||||
return APIResponse(success=True, message="Scheduler is already stopped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop scheduler: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to stop scheduler: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to stop scheduler: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -17,13 +17,25 @@ database_service = DatabaseService()
|
||||
async def get_all_transactions(
|
||||
limit: Optional[int] = Query(default=100, le=500),
|
||||
offset: Optional[int] = Query(default=0, ge=0),
|
||||
summary_only: bool = Query(default=True, description="Return transaction summaries only"),
|
||||
date_from: Optional[str] = Query(default=None, description="Filter from date (YYYY-MM-DD)"),
|
||||
date_to: Optional[str] = Query(default=None, description="Filter to date (YYYY-MM-DD)"),
|
||||
min_amount: Optional[float] = Query(default=None, description="Minimum transaction amount"),
|
||||
max_amount: Optional[float] = Query(default=None, description="Maximum transaction amount"),
|
||||
search: Optional[str] = Query(default=None, description="Search in transaction descriptions"),
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID")
|
||||
summary_only: bool = Query(
|
||||
default=True, description="Return transaction summaries only"
|
||||
),
|
||||
date_from: Optional[str] = Query(
|
||||
default=None, description="Filter from date (YYYY-MM-DD)"
|
||||
),
|
||||
date_to: Optional[str] = Query(
|
||||
default=None, description="Filter to date (YYYY-MM-DD)"
|
||||
),
|
||||
min_amount: Optional[float] = Query(
|
||||
default=None, description="Minimum transaction amount"
|
||||
),
|
||||
max_amount: Optional[float] = Query(
|
||||
default=None, description="Maximum transaction amount"
|
||||
),
|
||||
search: Optional[str] = Query(
|
||||
default=None, description="Search in transaction descriptions"
|
||||
),
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||
) -> APIResponse:
|
||||
"""Get all transactions across all accounts with filtering options"""
|
||||
try:
|
||||
@@ -46,7 +58,9 @@ async def get_all_transactions(
|
||||
for acc_id in all_accounts:
|
||||
try:
|
||||
account_details = await gocardless_service.get_account_details(acc_id)
|
||||
transactions_data = await gocardless_service.get_account_transactions(acc_id)
|
||||
transactions_data = await gocardless_service.get_account_transactions(
|
||||
acc_id
|
||||
)
|
||||
|
||||
processed_transactions = database_service.process_transactions(
|
||||
acc_id, account_details, transactions_data
|
||||
@@ -64,27 +78,31 @@ async def get_all_transactions(
|
||||
if date_from:
|
||||
from_date = datetime.fromisoformat(date_from)
|
||||
filtered_transactions = [
|
||||
txn for txn in filtered_transactions
|
||||
txn
|
||||
for txn in filtered_transactions
|
||||
if txn["transactionDate"] >= from_date
|
||||
]
|
||||
|
||||
if date_to:
|
||||
to_date = datetime.fromisoformat(date_to)
|
||||
filtered_transactions = [
|
||||
txn for txn in filtered_transactions
|
||||
txn
|
||||
for txn in filtered_transactions
|
||||
if txn["transactionDate"] <= to_date
|
||||
]
|
||||
|
||||
# Amount filters
|
||||
if min_amount is not None:
|
||||
filtered_transactions = [
|
||||
txn for txn in filtered_transactions
|
||||
txn
|
||||
for txn in filtered_transactions
|
||||
if txn["transactionValue"] >= min_amount
|
||||
]
|
||||
|
||||
if max_amount is not None:
|
||||
filtered_transactions = [
|
||||
txn for txn in filtered_transactions
|
||||
txn
|
||||
for txn in filtered_transactions
|
||||
if txn["transactionValue"] <= max_amount
|
||||
]
|
||||
|
||||
@@ -92,15 +110,13 @@ async def get_all_transactions(
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_transactions = [
|
||||
txn for txn in filtered_transactions
|
||||
txn
|
||||
for txn in filtered_transactions
|
||||
if search_lower in txn["description"].lower()
|
||||
]
|
||||
|
||||
# Sort by date (newest first)
|
||||
filtered_transactions.sort(
|
||||
key=lambda x: x["transactionDate"],
|
||||
reverse=True
|
||||
)
|
||||
filtered_transactions.sort(key=lambda x: x["transactionDate"], reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
total_transactions = len(filtered_transactions)
|
||||
@@ -116,7 +132,7 @@ async def get_all_transactions(
|
||||
amount=txn["transactionValue"],
|
||||
currency=txn["transactionCurrency"],
|
||||
status=txn["transactionStatus"],
|
||||
account_id=txn["accountId"]
|
||||
account_id=txn["accountId"],
|
||||
)
|
||||
for txn in paginated_transactions
|
||||
]
|
||||
@@ -133,7 +149,7 @@ async def get_all_transactions(
|
||||
transaction_value=txn["transactionValue"],
|
||||
transaction_currency=txn["transactionCurrency"],
|
||||
transaction_status=txn["transactionStatus"],
|
||||
raw_transaction=txn["rawTransaction"]
|
||||
raw_transaction=txn["rawTransaction"],
|
||||
)
|
||||
for txn in paginated_transactions
|
||||
]
|
||||
@@ -141,18 +157,20 @@ async def get_all_transactions(
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=data,
|
||||
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})"
|
||||
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transactions: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get transactions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get transactions: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/transactions/stats", response_model=APIResponse)
|
||||
async def get_transaction_stats(
|
||||
days: int = Query(default=30, description="Number of days to include in stats"),
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID")
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||
) -> APIResponse:
|
||||
"""Get transaction statistics for the last N days"""
|
||||
try:
|
||||
@@ -178,7 +196,9 @@ async def get_transaction_stats(
|
||||
for acc_id in all_accounts:
|
||||
try:
|
||||
account_details = await gocardless_service.get_account_details(acc_id)
|
||||
transactions_data = await gocardless_service.get_account_transactions(acc_id)
|
||||
transactions_data = await gocardless_service.get_account_transactions(
|
||||
acc_id
|
||||
)
|
||||
|
||||
processed_transactions = database_service.process_transactions(
|
||||
acc_id, account_details, transactions_data
|
||||
@@ -191,7 +211,8 @@ async def get_transaction_stats(
|
||||
|
||||
# Filter transactions by date range
|
||||
recent_transactions = [
|
||||
txn for txn in all_transactions
|
||||
txn
|
||||
for txn in all_transactions
|
||||
if start_date <= txn["transactionDate"] <= end_date
|
||||
]
|
||||
|
||||
@@ -210,8 +231,16 @@ async def get_transaction_stats(
|
||||
net_change = total_income - total_expenses
|
||||
|
||||
# Count by status
|
||||
booked_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "booked"])
|
||||
pending_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "pending"])
|
||||
booked_count = len(
|
||||
[txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]
|
||||
)
|
||||
pending_count = len(
|
||||
[
|
||||
txn
|
||||
for txn in recent_transactions
|
||||
if txn["transactionStatus"] == "pending"
|
||||
]
|
||||
)
|
||||
|
||||
stats = {
|
||||
"period_days": days,
|
||||
@@ -222,17 +251,23 @@ async def get_transaction_stats(
|
||||
"total_expenses": round(total_expenses, 2),
|
||||
"net_change": round(net_change, 2),
|
||||
"average_transaction": round(
|
||||
sum(txn["transactionValue"] for txn in recent_transactions) / total_transactions, 2
|
||||
) if total_transactions > 0 else 0,
|
||||
"accounts_included": len(all_accounts)
|
||||
sum(txn["transactionValue"] for txn in recent_transactions)
|
||||
/ total_transactions,
|
||||
2,
|
||||
)
|
||||
if total_transactions > 0
|
||||
else 0,
|
||||
"accounts_included": len(all_accounts),
|
||||
}
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=stats,
|
||||
message=f"Transaction statistics for last {days} days"
|
||||
message=f"Transaction statistics for last {days} days",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transaction stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get transaction stats: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get transaction stats: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -4,12 +4,16 @@ from loguru import logger
|
||||
|
||||
from leggend.config import config
|
||||
from leggend.services.sync_service import SyncService
|
||||
from leggend.services.notification_service import NotificationService
|
||||
|
||||
|
||||
class BackgroundScheduler:
|
||||
def __init__(self):
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.sync_service = SyncService()
|
||||
self.notification_service = NotificationService()
|
||||
self.max_retries = 3
|
||||
self.retry_delay = 300 # 5 minutes
|
||||
|
||||
def start(self):
|
||||
"""Start the scheduler and configure sync jobs based on configuration"""
|
||||
@@ -20,31 +24,10 @@ class BackgroundScheduler:
|
||||
self.scheduler.start()
|
||||
return
|
||||
|
||||
# Use custom cron expression if provided, otherwise use hour/minute
|
||||
if schedule_config.get("cron"):
|
||||
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
|
||||
try:
|
||||
cron_parts = schedule_config["cron"].split()
|
||||
if len(cron_parts) == 5:
|
||||
minute, hour, day, month, day_of_week = cron_parts
|
||||
trigger = CronTrigger(
|
||||
minute=minute,
|
||||
hour=hour,
|
||||
day=day if day != "*" else None,
|
||||
month=month if month != "*" else None,
|
||||
day_of_week=day_of_week if day_of_week != "*" else None,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
||||
# Parse schedule configuration
|
||||
trigger = self._parse_cron_config(schedule_config)
|
||||
if not trigger:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing cron expression: {e}")
|
||||
return
|
||||
else:
|
||||
# Use hour/minute configuration (default: 3:00 AM daily)
|
||||
hour = schedule_config.get("hour", 3)
|
||||
minute = schedule_config.get("minute", 0)
|
||||
trigger = CronTrigger(hour=hour, minute=minute)
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._run_sync,
|
||||
@@ -76,28 +59,9 @@ class BackgroundScheduler:
|
||||
return
|
||||
|
||||
# Configure new schedule
|
||||
if schedule_config.get("cron"):
|
||||
try:
|
||||
cron_parts = schedule_config["cron"].split()
|
||||
if len(cron_parts) == 5:
|
||||
minute, hour, day, month, day_of_week = cron_parts
|
||||
trigger = CronTrigger(
|
||||
minute=minute,
|
||||
hour=hour,
|
||||
day=day if day != "*" else None,
|
||||
month=month if month != "*" else None,
|
||||
day_of_week=day_of_week if day_of_week != "*" else None,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
||||
trigger = self._parse_cron_config(schedule_config)
|
||||
if not trigger:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing cron expression: {e}")
|
||||
return
|
||||
else:
|
||||
hour = schedule_config.get("hour", 3)
|
||||
minute = schedule_config.get("minute", 0)
|
||||
trigger = CronTrigger(hour=hour, minute=minute)
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._run_sync,
|
||||
@@ -108,13 +72,90 @@ class BackgroundScheduler:
|
||||
)
|
||||
logger.info(f"Rescheduled sync job with: {trigger}")
|
||||
|
||||
async def _run_sync(self):
|
||||
def _parse_cron_config(self, schedule_config: dict) -> CronTrigger:
|
||||
"""Parse cron configuration and return CronTrigger"""
|
||||
if schedule_config.get("cron"):
|
||||
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
|
||||
try:
|
||||
cron_parts = schedule_config["cron"].split()
|
||||
if len(cron_parts) == 5:
|
||||
minute, hour, day, month, day_of_week = cron_parts
|
||||
return CronTrigger(
|
||||
minute=minute,
|
||||
hour=hour,
|
||||
day=day if day != "*" else None,
|
||||
month=month if month != "*" else None,
|
||||
day_of_week=day_of_week if day_of_week != "*" else None,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing cron expression: {e}")
|
||||
return None
|
||||
else:
|
||||
# Use hour/minute configuration (default: 3:00 AM daily)
|
||||
hour = schedule_config.get("hour", 3)
|
||||
minute = schedule_config.get("minute", 0)
|
||||
return CronTrigger(hour=hour, minute=minute)
|
||||
|
||||
async def _run_sync(self, retry_count: int = 0):
|
||||
"""Run sync with enhanced error handling and retry logic"""
|
||||
try:
|
||||
logger.info("Starting scheduled sync job")
|
||||
await self.sync_service.sync_all_accounts()
|
||||
logger.info("Scheduled sync job completed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduled sync job failed: {e}")
|
||||
logger.error(
|
||||
f"Scheduled sync job failed (attempt {retry_count + 1}/{self.max_retries}): {e}"
|
||||
)
|
||||
|
||||
# Send notification about the failure
|
||||
try:
|
||||
await self.notification_service.send_expiry_notification(
|
||||
{
|
||||
"type": "sync_failure",
|
||||
"error": str(e),
|
||||
"retry_count": retry_count + 1,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
)
|
||||
except Exception as notification_error:
|
||||
logger.error(
|
||||
f"Failed to send failure notification: {notification_error}"
|
||||
)
|
||||
|
||||
# Implement retry logic for transient failures
|
||||
if retry_count < self.max_retries - 1:
|
||||
import datetime
|
||||
|
||||
logger.info(f"Retrying sync job in {self.retry_delay} seconds...")
|
||||
# Schedule a retry
|
||||
retry_time = datetime.datetime.now() + datetime.timedelta(
|
||||
seconds=self.retry_delay
|
||||
)
|
||||
self.scheduler.add_job(
|
||||
self._run_sync,
|
||||
"date",
|
||||
args=[retry_count + 1],
|
||||
id=f"sync_retry_{retry_count + 1}",
|
||||
run_date=retry_time,
|
||||
)
|
||||
else:
|
||||
logger.error("Maximum retries exceeded for sync job")
|
||||
# Send final failure notification
|
||||
try:
|
||||
await self.notification_service.send_expiry_notification(
|
||||
{
|
||||
"type": "sync_final_failure",
|
||||
"error": str(e),
|
||||
"retry_count": retry_count + 1,
|
||||
}
|
||||
)
|
||||
except Exception as notification_error:
|
||||
logger.error(
|
||||
f"Failed to send final failure notification: {notification_error}"
|
||||
)
|
||||
|
||||
def get_next_sync_time(self):
|
||||
"""Get the next scheduled sync time"""
|
||||
|
||||
@@ -24,7 +24,7 @@ class Config:
|
||||
if config_path is None:
|
||||
config_path = os.environ.get(
|
||||
"LEGGEN_CONFIG_FILE",
|
||||
str(Path.home() / ".config" / "leggen" / "config.toml")
|
||||
str(Path.home() / ".config" / "leggen" / "config.toml"),
|
||||
)
|
||||
|
||||
self._config_path = config_path
|
||||
@@ -42,7 +42,9 @@ class Config:
|
||||
|
||||
return self._config
|
||||
|
||||
def save_config(self, config_data: Dict[str, Any] = None, config_path: str = None) -> None:
|
||||
def save_config(
|
||||
self, config_data: Dict[str, Any] = None, config_path: str = None
|
||||
) -> None:
|
||||
"""Save configuration to TOML file"""
|
||||
if config_data is None:
|
||||
config_data = self._config
|
||||
@@ -50,7 +52,7 @@ class Config:
|
||||
if config_path is None:
|
||||
config_path = self._config_path or os.environ.get(
|
||||
"LEGGEN_CONFIG_FILE",
|
||||
str(Path.home() / ".config" / "leggen" / "config.toml")
|
||||
str(Path.home() / ".config" / "leggen" / "config.toml"),
|
||||
)
|
||||
|
||||
# Ensure directory exists
|
||||
@@ -117,7 +119,7 @@ class Config:
|
||||
"enabled": True,
|
||||
"hour": 3,
|
||||
"minute": 0,
|
||||
"cron": None # Optional custom cron expression
|
||||
"cron": None, # Optional custom cron expression
|
||||
}
|
||||
}
|
||||
return self.config.get("scheduler", default_schedule)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib import metadata
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
@@ -36,17 +37,26 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
# Get version dynamically from package metadata
|
||||
try:
|
||||
version = metadata.version("leggen")
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "unknown"
|
||||
|
||||
app = FastAPI(
|
||||
title="Leggend API",
|
||||
description="Open Banking API for Leggen",
|
||||
version="0.6.11",
|
||||
version=version,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000", "http://localhost:5173"], # SvelteKit dev servers
|
||||
allow_origins=[
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
], # SvelteKit dev servers
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
@@ -60,7 +70,12 @@ def create_app() -> FastAPI:
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Leggend API is running", "version": "0.6.11"}
|
||||
# Get version dynamically
|
||||
try:
|
||||
version = metadata.version("leggen")
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "unknown"
|
||||
return {"message": "Leggend API is running", "version": version}
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
@@ -71,22 +86,16 @@ def create_app() -> FastAPI:
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Start the Leggend API service")
|
||||
parser.add_argument(
|
||||
"--reload",
|
||||
action="store_true",
|
||||
help="Enable auto-reload for development"
|
||||
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default="0.0.0.0",
|
||||
help="Host to bind to (default: 0.0.0.0)"
|
||||
"--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to bind to (default: 8000)"
|
||||
"--port", type=int, default=8000, help="Port to bind to (default: 8000)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -11,7 +11,9 @@ class DatabaseService:
|
||||
self.db_config = config.database_config
|
||||
self.sqlite_enabled = self.db_config.get("sqlite", True)
|
||||
|
||||
async def persist_balance(self, account_id: str, balance_data: Dict[str, Any]) -> None:
|
||||
async def persist_balance(
|
||||
self, account_id: str, balance_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist account balance data"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, skipping balance persistence")
|
||||
@@ -19,7 +21,9 @@ class DatabaseService:
|
||||
|
||||
await self._persist_balance_sqlite(account_id, balance_data)
|
||||
|
||||
async def persist_transactions(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def persist_transactions(
|
||||
self, account_id: str, transactions: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Persist transactions and return new transactions"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, skipping transaction persistence")
|
||||
@@ -27,32 +31,48 @@ class DatabaseService:
|
||||
|
||||
return await self._persist_transactions_sqlite(account_id, transactions)
|
||||
|
||||
def process_transactions(self, account_id: str, account_info: Dict[str, Any], transaction_data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
def process_transactions(
|
||||
self,
|
||||
account_id: str,
|
||||
account_info: Dict[str, Any],
|
||||
transaction_data: Dict[str, Any],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process raw transaction data into standardized format"""
|
||||
transactions = []
|
||||
|
||||
# Process booked transactions
|
||||
for transaction in transaction_data.get("transactions", {}).get("booked", []):
|
||||
processed = self._process_single_transaction(account_id, account_info, transaction, "booked")
|
||||
processed = self._process_single_transaction(
|
||||
account_id, account_info, transaction, "booked"
|
||||
)
|
||||
transactions.append(processed)
|
||||
|
||||
# Process pending transactions
|
||||
for transaction in transaction_data.get("transactions", {}).get("pending", []):
|
||||
processed = self._process_single_transaction(account_id, account_info, transaction, "pending")
|
||||
processed = self._process_single_transaction(
|
||||
account_id, account_info, transaction, "pending"
|
||||
)
|
||||
transactions.append(processed)
|
||||
|
||||
return transactions
|
||||
|
||||
def _process_single_transaction(self, account_id: str, account_info: Dict[str, Any], transaction: Dict[str, Any], status: str) -> Dict[str, Any]:
|
||||
def _process_single_transaction(
|
||||
self,
|
||||
account_id: str,
|
||||
account_info: Dict[str, Any],
|
||||
transaction: Dict[str, Any],
|
||||
status: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a single transaction into standardized format"""
|
||||
# Extract dates
|
||||
booked_date = transaction.get("bookingDateTime") or transaction.get("bookingDate")
|
||||
booked_date = transaction.get("bookingDateTime") or transaction.get(
|
||||
"bookingDate"
|
||||
)
|
||||
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
|
||||
|
||||
if booked_date and value_date:
|
||||
min_date = min(
|
||||
datetime.fromisoformat(booked_date),
|
||||
datetime.fromisoformat(value_date)
|
||||
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
|
||||
)
|
||||
else:
|
||||
min_date = datetime.fromisoformat(booked_date or value_date)
|
||||
@@ -65,7 +85,7 @@ class DatabaseService:
|
||||
# Extract description
|
||||
description = transaction.get(
|
||||
"remittanceInformationUnstructured",
|
||||
",".join(transaction.get("remittanceInformationUnstructuredArray", []))
|
||||
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -81,13 +101,19 @@ class DatabaseService:
|
||||
"rawTransaction": transaction,
|
||||
}
|
||||
|
||||
async def _persist_balance_sqlite(self, account_id: str, balance_data: Dict[str, Any]) -> None:
|
||||
async def _persist_balance_sqlite(
|
||||
self, account_id: str, balance_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist balance to SQLite - placeholder implementation"""
|
||||
# Would import and use leggen.database.sqlite
|
||||
logger.info(f"Persisting balance to SQLite for account {account_id}")
|
||||
|
||||
async def _persist_transactions_sqlite(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def _persist_transactions_sqlite(
|
||||
self, account_id: str, transactions: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Persist transactions to SQLite - placeholder implementation"""
|
||||
# Would import and use leggen.database.sqlite
|
||||
logger.info(f"Persisting {len(transactions)} transactions to SQLite for account {account_id}")
|
||||
logger.info(
|
||||
f"Persisting {len(transactions)} transactions to SQLite for account {account_id}"
|
||||
)
|
||||
return transactions # Return new transactions for notifications
|
||||
@@ -12,16 +12,15 @@ from leggend.config import config
|
||||
class GoCardlessService:
|
||||
def __init__(self):
|
||||
self.config = config.gocardless_config
|
||||
self.base_url = self.config.get("url", "https://bankaccountdata.gocardless.com/api/v2")
|
||||
self.base_url = self.config.get(
|
||||
"url", "https://bankaccountdata.gocardless.com/api/v2"
|
||||
)
|
||||
self._token = None
|
||||
|
||||
async def _get_auth_headers(self) -> Dict[str, str]:
|
||||
"""Get authentication headers for GoCardless API"""
|
||||
token = await self._get_token()
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
async def _get_token(self) -> str:
|
||||
"""Get access token for GoCardless API"""
|
||||
@@ -42,7 +41,7 @@ class GoCardlessService:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/token/refresh/",
|
||||
json={"refresh": auth["refresh"]}
|
||||
json={"refresh": auth["refresh"]},
|
||||
)
|
||||
response.raise_for_status()
|
||||
auth.update(response.json())
|
||||
@@ -95,22 +94,21 @@ class GoCardlessService:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/institutions/",
|
||||
headers=headers,
|
||||
params={"country": country}
|
||||
params={"country": country},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def create_requisition(self, institution_id: str, redirect_url: str) -> Dict[str, Any]:
|
||||
async def create_requisition(
|
||||
self, institution_id: str, redirect_url: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a bank connection requisition"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/requisitions/",
|
||||
headers=headers,
|
||||
json={
|
||||
"institution_id": institution_id,
|
||||
"redirect": redirect_url
|
||||
}
|
||||
json={"institution_id": institution_id, "redirect": redirect_url},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -120,8 +118,7 @@ class GoCardlessService:
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/requisitions/",
|
||||
headers=headers
|
||||
f"{self.base_url}/requisitions/", headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -131,8 +128,7 @@ class GoCardlessService:
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/accounts/{account_id}/",
|
||||
headers=headers
|
||||
f"{self.base_url}/accounts/{account_id}/", headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -142,8 +138,7 @@ class GoCardlessService:
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/accounts/{account_id}/balances/",
|
||||
headers=headers
|
||||
f"{self.base_url}/accounts/{account_id}/balances/", headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -153,8 +148,7 @@ class GoCardlessService:
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/accounts/{account_id}/transactions/",
|
||||
headers=headers
|
||||
f"{self.base_url}/accounts/{account_id}/transactions/", headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -10,7 +10,9 @@ class NotificationService:
|
||||
self.notifications_config = config.notifications_config
|
||||
self.filters_config = config.filters_config
|
||||
|
||||
async def send_transaction_notifications(self, transactions: List[Dict[str, Any]]) -> None:
|
||||
async def send_transaction_notifications(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Send notifications for new transactions that match filters"""
|
||||
if not self.filters_config:
|
||||
logger.info("No notification filters configured, skipping notifications")
|
||||
@@ -40,7 +42,9 @@ class NotificationService:
|
||||
await self._send_telegram_test(message)
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Notification service '{service}' not enabled or not found")
|
||||
logger.error(
|
||||
f"Notification service '{service}' not enabled or not found"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send test notification to {service}: {e}")
|
||||
@@ -54,7 +58,9 @@ class NotificationService:
|
||||
if self._is_telegram_enabled():
|
||||
await self._send_telegram_expiry(notification_data)
|
||||
|
||||
def _filter_transactions(self, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _filter_transactions(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter transactions based on notification criteria"""
|
||||
matching = []
|
||||
filters_case_insensitive = self.filters_config.get("case-insensitive", {})
|
||||
@@ -65,12 +71,14 @@ class NotificationService:
|
||||
# Check case-insensitive filters
|
||||
for filter_name, filter_value in filters_case_insensitive.items():
|
||||
if filter_value.lower() in description:
|
||||
matching.append({
|
||||
matching.append(
|
||||
{
|
||||
"name": transaction["description"],
|
||||
"value": transaction["transactionValue"],
|
||||
"currency": transaction["transactionCurrency"],
|
||||
"date": transaction["transactionDate"],
|
||||
})
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
return matching
|
||||
@@ -78,26 +86,34 @@ class NotificationService:
|
||||
def _is_discord_enabled(self) -> bool:
|
||||
"""Check if Discord notifications are enabled"""
|
||||
discord_config = self.notifications_config.get("discord", {})
|
||||
return bool(discord_config.get("webhook") and discord_config.get("enabled", True))
|
||||
return bool(
|
||||
discord_config.get("webhook") and discord_config.get("enabled", True)
|
||||
)
|
||||
|
||||
def _is_telegram_enabled(self) -> bool:
|
||||
"""Check if Telegram notifications are enabled"""
|
||||
telegram_config = self.notifications_config.get("telegram", {})
|
||||
return bool(
|
||||
telegram_config.get("token") and
|
||||
telegram_config.get("chat_id") and
|
||||
telegram_config.get("enabled", True)
|
||||
telegram_config.get("token")
|
||||
and telegram_config.get("chat_id")
|
||||
and telegram_config.get("enabled", True)
|
||||
)
|
||||
|
||||
async def _send_discord_notifications(self, transactions: List[Dict[str, Any]]) -> None:
|
||||
async def _send_discord_notifications(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Send Discord notifications - placeholder implementation"""
|
||||
# Would import and use leggen.notifications.discord
|
||||
logger.info(f"Sending {len(transactions)} transaction notifications to Discord")
|
||||
|
||||
async def _send_telegram_notifications(self, transactions: List[Dict[str, Any]]) -> None:
|
||||
async def _send_telegram_notifications(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Send Telegram notifications - placeholder implementation"""
|
||||
# Would import and use leggen.notifications.telegram
|
||||
logger.info(f"Sending {len(transactions)} transaction notifications to Telegram")
|
||||
logger.info(
|
||||
f"Sending {len(transactions)} transaction notifications to Telegram"
|
||||
)
|
||||
|
||||
async def _send_discord_test(self, message: str) -> None:
|
||||
"""Send Discord test notification"""
|
||||
|
||||
@@ -53,7 +53,9 @@ class SyncService:
|
||||
for account_id in all_accounts:
|
||||
try:
|
||||
# Get account details
|
||||
account_details = await self.gocardless.get_account_details(account_id)
|
||||
account_details = await self.gocardless.get_account_details(
|
||||
account_id
|
||||
)
|
||||
|
||||
# Get and save balances
|
||||
balances = await self.gocardless.get_account_balances(account_id)
|
||||
@@ -62,7 +64,9 @@ class SyncService:
|
||||
balances_updated += len(balances.get("balances", []))
|
||||
|
||||
# Get and save transactions
|
||||
transactions = await self.gocardless.get_account_transactions(account_id)
|
||||
transactions = await self.gocardless.get_account_transactions(
|
||||
account_id
|
||||
)
|
||||
if transactions:
|
||||
processed_transactions = self.database.process_transactions(
|
||||
account_id, account_details, transactions
|
||||
@@ -74,7 +78,9 @@ class SyncService:
|
||||
|
||||
# Send notifications for new transactions
|
||||
if new_transactions:
|
||||
await self.notifications.send_transaction_notifications(new_transactions)
|
||||
await self.notifications.send_transaction_notifications(
|
||||
new_transactions
|
||||
)
|
||||
|
||||
accounts_processed += 1
|
||||
self._sync_status.accounts_synced = accounts_processed
|
||||
@@ -100,10 +106,12 @@ class SyncService:
|
||||
duration_seconds=duration,
|
||||
errors=errors,
|
||||
started_at=start_time,
|
||||
completed_at=end_time
|
||||
completed_at=end_time,
|
||||
)
|
||||
|
||||
logger.info(f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions")
|
||||
logger.info(
|
||||
f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions"
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -114,7 +122,9 @@ class SyncService:
|
||||
finally:
|
||||
self._sync_status.is_running = False
|
||||
|
||||
async def sync_specific_accounts(self, account_ids: List[str], force: bool = False) -> SyncResult:
|
||||
async def sync_specific_accounts(
|
||||
self, account_ids: List[str], force: bool = False
|
||||
) -> SyncResult:
|
||||
"""Sync specific accounts"""
|
||||
if self._sync_status.is_running and not force:
|
||||
raise Exception("Sync is already running")
|
||||
@@ -139,7 +149,7 @@ class SyncService:
|
||||
duration_seconds=(end_time - start_time).total_seconds(),
|
||||
errors=[],
|
||||
started_at=start_time,
|
||||
completed_at=end_time
|
||||
completed_at=end_time,
|
||||
)
|
||||
finally:
|
||||
self._sync_status.is_running = False
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Pytest configuration and shared fixtures."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import json
|
||||
@@ -26,27 +27,20 @@ def mock_config(temp_config_dir):
|
||||
"gocardless": {
|
||||
"key": "test-key",
|
||||
"secret": "test-secret",
|
||||
"url": "https://bankaccountdata.gocardless.com/api/v2"
|
||||
"url": "https://bankaccountdata.gocardless.com/api/v2",
|
||||
},
|
||||
"database": {
|
||||
"sqlite": True
|
||||
},
|
||||
"scheduler": {
|
||||
"sync": {
|
||||
"enabled": True,
|
||||
"hour": 3,
|
||||
"minute": 0
|
||||
}
|
||||
}
|
||||
"database": {"sqlite": True},
|
||||
"scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}},
|
||||
}
|
||||
|
||||
config_file = temp_config_dir / "config.toml"
|
||||
with open(config_file, "wb") as f:
|
||||
import tomli_w
|
||||
|
||||
tomli_w.dump(config_data, f)
|
||||
|
||||
# Mock the config path
|
||||
with patch.object(Config, 'load_config') as mock_load:
|
||||
with patch.object(Config, "load_config") as mock_load:
|
||||
mock_load.return_value = config_data
|
||||
config = Config()
|
||||
config._config = config_data
|
||||
@@ -57,10 +51,7 @@ def mock_config(temp_config_dir):
|
||||
@pytest.fixture
|
||||
def mock_auth_token(temp_config_dir):
|
||||
"""Mock GoCardless authentication token."""
|
||||
auth_data = {
|
||||
"access": "mock-access-token",
|
||||
"refresh": "mock-refresh-token"
|
||||
}
|
||||
auth_data = {"access": "mock-access-token", "refresh": "mock-refresh-token"}
|
||||
|
||||
auth_file = temp_config_dir / "auth.json"
|
||||
with open(auth_file, "w") as f:
|
||||
@@ -90,15 +81,15 @@ def sample_bank_data():
|
||||
"name": "Revolut",
|
||||
"bic": "REVOLT21",
|
||||
"transaction_total_days": 90,
|
||||
"countries": ["GB", "LT"]
|
||||
"countries": ["GB", "LT"],
|
||||
},
|
||||
{
|
||||
"id": "BANCOBPI_BBPIPTPL",
|
||||
"name": "Banco BPI",
|
||||
"bic": "BBPIPTPL",
|
||||
"transaction_total_days": 90,
|
||||
"countries": ["PT"]
|
||||
}
|
||||
"countries": ["PT"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -111,7 +102,7 @@ def sample_account_data():
|
||||
"status": "READY",
|
||||
"iban": "LT313250081177977789",
|
||||
"created": "2024-02-13T23:56:00Z",
|
||||
"last_accessed": "2025-09-01T09:30:00Z"
|
||||
"last_accessed": "2025-09-01T09:30:00Z",
|
||||
}
|
||||
|
||||
|
||||
@@ -125,13 +116,10 @@ def sample_transaction_data():
|
||||
"internalTransactionId": "txn-123",
|
||||
"bookingDate": "2025-09-01",
|
||||
"valueDate": "2025-09-01",
|
||||
"transactionAmount": {
|
||||
"amount": "-10.50",
|
||||
"currency": "EUR"
|
||||
},
|
||||
"remittanceInformationUnstructured": "Coffee Shop Payment"
|
||||
"transactionAmount": {"amount": "-10.50", "currency": "EUR"},
|
||||
"remittanceInformationUnstructured": "Coffee Shop Payment",
|
||||
}
|
||||
],
|
||||
"pending": []
|
||||
"pending": [],
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for accounts API endpoints."""
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
import httpx
|
||||
@@ -10,15 +11,12 @@ class TestAccountsAPI:
|
||||
"""Test account-related API endpoints."""
|
||||
|
||||
@respx.mock
|
||||
def test_get_all_accounts_success(self, api_client, mock_config, mock_auth_token, sample_account_data):
|
||||
def test_get_all_accounts_success(
|
||||
self, api_client, mock_config, mock_auth_token, sample_account_data
|
||||
):
|
||||
"""Test successful retrieval of all accounts."""
|
||||
requisitions_data = {
|
||||
"results": [
|
||||
{
|
||||
"id": "req-123",
|
||||
"accounts": ["test-account-123"]
|
||||
}
|
||||
]
|
||||
"results": [{"id": "req-123", "accounts": ["test-account-123"]}]
|
||||
}
|
||||
|
||||
balances_data = {
|
||||
@@ -26,28 +24,30 @@ class TestAccountsAPI:
|
||||
{
|
||||
"balanceAmount": {"amount": "100.50", "currency": "EUR"},
|
||||
"balanceType": "interimAvailable",
|
||||
"lastChangeDateTime": "2025-09-01T09:30:00Z"
|
||||
"lastChangeDateTime": "2025-09-01T09:30:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless API calls
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
||||
return_value=httpx.Response(200, json=requisitions_data)
|
||||
)
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
||||
return_value=httpx.Response(200, json=sample_account_data)
|
||||
)
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock(
|
||||
return_value=httpx.Response(200, json=balances_data)
|
||||
)
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
|
||||
).mock(return_value=httpx.Response(200, json=balances_data))
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/accounts")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -61,31 +61,35 @@ class TestAccountsAPI:
|
||||
assert account["balances"][0]["amount"] == 100.50
|
||||
|
||||
@respx.mock
|
||||
def test_get_account_details_success(self, api_client, mock_config, mock_auth_token, sample_account_data):
|
||||
def test_get_account_details_success(
|
||||
self, api_client, mock_config, mock_auth_token, sample_account_data
|
||||
):
|
||||
"""Test successful retrieval of specific account details."""
|
||||
balances_data = {
|
||||
"balances": [
|
||||
{
|
||||
"balanceAmount": {"amount": "250.75", "currency": "EUR"},
|
||||
"balanceType": "interimAvailable"
|
||||
"balanceType": "interimAvailable",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless API calls
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
||||
return_value=httpx.Response(200, json=sample_account_data)
|
||||
)
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock(
|
||||
return_value=httpx.Response(200, json=balances_data)
|
||||
)
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
|
||||
).mock(return_value=httpx.Response(200, json=balances_data))
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/accounts/test-account-123")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -97,33 +101,37 @@ class TestAccountsAPI:
|
||||
assert len(account["balances"]) == 1
|
||||
|
||||
@respx.mock
|
||||
def test_get_account_balances_success(self, api_client, mock_config, mock_auth_token):
|
||||
def test_get_account_balances_success(
|
||||
self, api_client, mock_config, mock_auth_token
|
||||
):
|
||||
"""Test successful retrieval of account balances."""
|
||||
balances_data = {
|
||||
"balances": [
|
||||
{
|
||||
"balanceAmount": {"amount": "1000.00", "currency": "EUR"},
|
||||
"balanceType": "interimAvailable",
|
||||
"lastChangeDateTime": "2025-09-01T10:00:00Z"
|
||||
"lastChangeDateTime": "2025-09-01T10:00:00Z",
|
||||
},
|
||||
{
|
||||
"balanceAmount": {"amount": "950.00", "currency": "EUR"},
|
||||
"balanceType": "expected"
|
||||
}
|
||||
"balanceType": "expected",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless API
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock(
|
||||
return_value=httpx.Response(200, json=balances_data)
|
||||
)
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
|
||||
).mock(return_value=httpx.Response(200, json=balances_data))
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/accounts/test-account-123/balances")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -135,23 +143,34 @@ class TestAccountsAPI:
|
||||
assert data["data"][0]["balance_type"] == "interimAvailable"
|
||||
|
||||
@respx.mock
|
||||
def test_get_account_transactions_success(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data):
|
||||
def test_get_account_transactions_success(
|
||||
self,
|
||||
api_client,
|
||||
mock_config,
|
||||
mock_auth_token,
|
||||
sample_account_data,
|
||||
sample_transaction_data,
|
||||
):
|
||||
"""Test successful retrieval of account transactions."""
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless API calls
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
||||
return_value=httpx.Response(200, json=sample_account_data)
|
||||
)
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock(
|
||||
return_value=httpx.Response(200, json=sample_transaction_data)
|
||||
)
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
|
||||
).mock(return_value=httpx.Response(200, json=sample_transaction_data))
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=true")
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get(
|
||||
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -165,23 +184,34 @@ class TestAccountsAPI:
|
||||
assert transaction["description"] == "Coffee Shop Payment"
|
||||
|
||||
@respx.mock
|
||||
def test_get_account_transactions_full_details(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data):
|
||||
def test_get_account_transactions_full_details(
|
||||
self,
|
||||
api_client,
|
||||
mock_config,
|
||||
mock_auth_token,
|
||||
sample_account_data,
|
||||
sample_transaction_data,
|
||||
):
|
||||
"""Test retrieval of full transaction details."""
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless API calls
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
||||
return_value=httpx.Response(200, json=sample_account_data)
|
||||
)
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock(
|
||||
return_value=httpx.Response(200, json=sample_transaction_data)
|
||||
)
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
|
||||
).mock(return_value=httpx.Response(200, json=sample_transaction_data))
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=false")
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get(
|
||||
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -200,14 +230,18 @@ class TestAccountsAPI:
|
||||
with respx.mock:
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/").mock(
|
||||
respx.get(
|
||||
"https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/"
|
||||
).mock(
|
||||
return_value=httpx.Response(404, json={"detail": "Account not found"})
|
||||
)
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/accounts/nonexistent")
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for banks API endpoints."""
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
import httpx
|
||||
@@ -12,11 +13,15 @@ class TestBanksAPI:
|
||||
"""Test bank-related API endpoints."""
|
||||
|
||||
@respx.mock
|
||||
def test_get_institutions_success(self, api_client, mock_config, mock_auth_token, sample_bank_data):
|
||||
def test_get_institutions_success(
|
||||
self, api_client, mock_config, mock_auth_token, sample_bank_data
|
||||
):
|
||||
"""Test successful retrieval of bank institutions."""
|
||||
# Mock GoCardless token creation/refresh
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless institutions API
|
||||
@@ -24,7 +29,7 @@ class TestBanksAPI:
|
||||
return_value=httpx.Response(200, json=sample_bank_data)
|
||||
)
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/banks/institutions?country=PT")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -39,7 +44,9 @@ class TestBanksAPI:
|
||||
"""Test institutions endpoint with invalid country code."""
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock empty institutions response for invalid country
|
||||
@@ -47,7 +54,7 @@ class TestBanksAPI:
|
||||
return_value=httpx.Response(200, json=[])
|
||||
)
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/banks/institutions?country=XX")
|
||||
|
||||
# Should still work but return empty or filtered results
|
||||
@@ -61,12 +68,14 @@ class TestBanksAPI:
|
||||
"institution_id": "REVOLUT_REVOLT21",
|
||||
"status": "CR",
|
||||
"created": "2025-09-02T00:00:00Z",
|
||||
"link": "https://example.com/auth"
|
||||
"link": "https://example.com/auth",
|
||||
}
|
||||
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless requisitions API
|
||||
@@ -76,10 +85,10 @@ class TestBanksAPI:
|
||||
|
||||
request_data = {
|
||||
"institution_id": "REVOLUT_REVOLT21",
|
||||
"redirect_url": "http://localhost:8000/"
|
||||
"redirect_url": "http://localhost:8000/",
|
||||
}
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.post("/api/v1/banks/connect", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -98,14 +107,16 @@ class TestBanksAPI:
|
||||
"institution_id": "REVOLUT_REVOLT21",
|
||||
"status": "LN",
|
||||
"created": "2025-09-02T00:00:00Z",
|
||||
"accounts": ["acc-123"]
|
||||
"accounts": ["acc-123"],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Mock GoCardless token creation
|
||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
||||
return_value=httpx.Response(
|
||||
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||
)
|
||||
)
|
||||
|
||||
# Mock GoCardless requisitions API
|
||||
@@ -113,7 +124,7 @@ class TestBanksAPI:
|
||||
return_value=httpx.Response(200, json=requisitions_data)
|
||||
)
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/banks/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -146,7 +157,7 @@ class TestBanksAPI:
|
||||
return_value=httpx.Response(401, json={"detail": "Invalid credentials"})
|
||||
)
|
||||
|
||||
with patch('leggend.config.config', mock_config):
|
||||
with patch("leggend.config.config", mock_config):
|
||||
response = api_client.get("/api/v1/banks/institutions")
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for CLI API client."""
|
||||
|
||||
import pytest
|
||||
import requests_mock
|
||||
from unittest.mock import patch
|
||||
@@ -37,7 +38,7 @@ class TestLeggendAPIClient:
|
||||
api_response = {
|
||||
"success": True,
|
||||
"data": sample_bank_data,
|
||||
"message": "Found 2 institutions for PT"
|
||||
"message": "Found 2 institutions for PT",
|
||||
}
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
@@ -54,7 +55,7 @@ class TestLeggendAPIClient:
|
||||
api_response = {
|
||||
"success": True,
|
||||
"data": [sample_account_data],
|
||||
"message": "Retrieved 1 accounts"
|
||||
"message": "Retrieved 1 accounts",
|
||||
}
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
@@ -71,7 +72,7 @@ class TestLeggendAPIClient:
|
||||
api_response = {
|
||||
"success": True,
|
||||
"data": {"sync_started": True, "force": False},
|
||||
"message": "Started sync for all accounts"
|
||||
"message": "Started sync for all accounts",
|
||||
}
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
@@ -92,8 +93,11 @@ class TestLeggendAPIClient:
|
||||
client = LeggendAPIClient("http://localhost:8000")
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get("http://localhost:8000/api/v1/accounts", status_code=500,
|
||||
json={"detail": "Internal server error"})
|
||||
m.get(
|
||||
"http://localhost:8000/api/v1/accounts",
|
||||
status_code=500,
|
||||
json={"detail": "Internal server error"},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
client.get_accounts()
|
||||
@@ -107,7 +111,7 @@ class TestLeggendAPIClient:
|
||||
|
||||
def test_environment_variable_url(self):
|
||||
"""Test using environment variable for API URL."""
|
||||
with patch.dict('os.environ', {'LEGGEND_API_URL': 'http://env-host:7000'}):
|
||||
with patch.dict("os.environ", {"LEGGEND_API_URL": "http://env-host:7000"}):
|
||||
client = LeggendAPIClient()
|
||||
assert client.base_url == "http://env-host:7000"
|
||||
|
||||
@@ -118,7 +122,7 @@ class TestLeggendAPIClient:
|
||||
api_response = {
|
||||
"success": True,
|
||||
"data": {"sync_started": True, "force": True},
|
||||
"message": "Started sync for 2 specific accounts"
|
||||
"message": "Started sync for 2 specific accounts",
|
||||
}
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
@@ -138,8 +142,8 @@ class TestLeggendAPIClient:
|
||||
"enabled": True,
|
||||
"hour": 3,
|
||||
"minute": 0,
|
||||
"next_scheduled_sync": "2025-09-03T03:00:00Z"
|
||||
}
|
||||
"next_scheduled_sync": "2025-09-03T03:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for configuration management."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -23,16 +24,15 @@ class TestConfig:
|
||||
"gocardless": {
|
||||
"key": "test-key",
|
||||
"secret": "test-secret",
|
||||
"url": "https://test.example.com"
|
||||
"url": "https://test.example.com",
|
||||
},
|
||||
"database": {
|
||||
"sqlite": True
|
||||
}
|
||||
"database": {"sqlite": True},
|
||||
}
|
||||
|
||||
config_file = temp_config_dir / "config.toml"
|
||||
with open(config_file, "wb") as f:
|
||||
import tomli_w
|
||||
|
||||
tomli_w.dump(config_data, f)
|
||||
|
||||
config = Config()
|
||||
@@ -56,12 +56,7 @@ class TestConfig:
|
||||
|
||||
def test_save_config_success(self, temp_config_dir):
|
||||
"""Test successful configuration saving."""
|
||||
config_data = {
|
||||
"gocardless": {
|
||||
"key": "new-key",
|
||||
"secret": "new-secret"
|
||||
}
|
||||
}
|
||||
config_data = {"gocardless": {"key": "new-key", "secret": "new-secret"}}
|
||||
|
||||
config_file = temp_config_dir / "new_config.toml"
|
||||
config = Config()
|
||||
@@ -73,6 +68,7 @@ class TestConfig:
|
||||
assert config_file.exists()
|
||||
|
||||
import tomllib
|
||||
|
||||
with open(config_file, "rb") as f:
|
||||
saved_data = tomllib.load(f)
|
||||
|
||||
@@ -82,12 +78,13 @@ class TestConfig:
|
||||
"""Test updating configuration values."""
|
||||
initial_config = {
|
||||
"gocardless": {"key": "old-key"},
|
||||
"database": {"sqlite": True}
|
||||
"database": {"sqlite": True},
|
||||
}
|
||||
|
||||
config_file = temp_config_dir / "config.toml"
|
||||
with open(config_file, "wb") as f:
|
||||
import tomli_w
|
||||
|
||||
tomli_w.dump(initial_config, f)
|
||||
|
||||
config = Config()
|
||||
@@ -100,19 +97,19 @@ class TestConfig:
|
||||
|
||||
# Verify it was saved to file
|
||||
import tomllib
|
||||
|
||||
with open(config_file, "rb") as f:
|
||||
saved_data = tomllib.load(f)
|
||||
assert saved_data["gocardless"]["key"] == "new-key"
|
||||
|
||||
def test_update_section_success(self, temp_config_dir):
|
||||
"""Test updating entire configuration section."""
|
||||
initial_config = {
|
||||
"database": {"sqlite": True}
|
||||
}
|
||||
initial_config = {"database": {"sqlite": True}}
|
||||
|
||||
config_file = temp_config_dir / "config.toml"
|
||||
with open(config_file, "wb") as f:
|
||||
import tomli_w
|
||||
|
||||
tomli_w.dump(initial_config, f)
|
||||
|
||||
config = Config()
|
||||
@@ -144,7 +141,7 @@ class TestConfig:
|
||||
"enabled": False,
|
||||
"hour": 6,
|
||||
"minute": 30,
|
||||
"cron": "0 6 * * 1-5"
|
||||
"cron": "0 6 * * 1-5",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,11 +158,13 @@ class TestConfig:
|
||||
|
||||
def test_environment_variable_config_path(self):
|
||||
"""Test using environment variable for config path."""
|
||||
with patch.dict('os.environ', {'LEGGEN_CONFIG_FILE': '/custom/path/config.toml'}):
|
||||
with patch.dict(
|
||||
"os.environ", {"LEGGEN_CONFIG_FILE": "/custom/path/config.toml"}
|
||||
):
|
||||
config = Config()
|
||||
config._config = None
|
||||
|
||||
with patch('builtins.open', side_effect=FileNotFoundError):
|
||||
with patch("builtins.open", side_effect=FileNotFoundError):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
config.load_config()
|
||||
|
||||
@@ -174,7 +173,7 @@ class TestConfig:
|
||||
custom_config = {
|
||||
"notifications": {
|
||||
"discord": {"webhook": "https://discord.webhook", "enabled": True},
|
||||
"telegram": {"token": "bot-token", "chat_id": 123}
|
||||
"telegram": {"token": "bot-token", "chat_id": 123},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +189,7 @@ class TestConfig:
|
||||
custom_config = {
|
||||
"filters": {
|
||||
"case-insensitive": {"salary": "SALARY", "bills": "BILL"},
|
||||
"amount_threshold": 100.0
|
||||
"amount_threshold": 100.0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for background scheduler."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
@@ -16,22 +17,18 @@ class TestBackgroundScheduler:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Mock configuration for scheduler tests."""
|
||||
return {
|
||||
"sync": {
|
||||
"enabled": True,
|
||||
"hour": 3,
|
||||
"minute": 0,
|
||||
"cron": None
|
||||
}
|
||||
}
|
||||
return {"sync": {"enabled": True, "hour": 3, "minute": 0, "cron": None}}
|
||||
|
||||
@pytest.fixture
|
||||
def scheduler(self):
|
||||
"""Create scheduler instance for testing."""
|
||||
with patch('leggend.background.scheduler.SyncService'), \
|
||||
patch('leggend.background.scheduler.config') as mock_config:
|
||||
|
||||
mock_config.scheduler_config = {"sync": {"enabled": True, "hour": 3, "minute": 0}}
|
||||
with (
|
||||
patch("leggend.background.scheduler.SyncService"),
|
||||
patch("leggend.background.scheduler.config") as mock_config,
|
||||
):
|
||||
mock_config.scheduler_config = {
|
||||
"sync": {"enabled": True, "hour": 3, "minute": 0}
|
||||
}
|
||||
|
||||
# Create scheduler and replace its AsyncIO scheduler with a mock
|
||||
scheduler = BackgroundScheduler()
|
||||
@@ -43,7 +40,7 @@ class TestBackgroundScheduler:
|
||||
|
||||
def test_scheduler_start_default_config(self, scheduler, mock_config):
|
||||
"""Test starting scheduler with default configuration."""
|
||||
with patch('leggend.config.config') as mock_config_obj:
|
||||
with patch("leggend.config.config") as mock_config_obj:
|
||||
mock_config_obj.scheduler_config = mock_config
|
||||
|
||||
# Mock the job that gets added
|
||||
@@ -60,13 +57,12 @@ class TestBackgroundScheduler:
|
||||
|
||||
def test_scheduler_start_disabled(self, scheduler):
|
||||
"""Test scheduler behavior when sync is disabled."""
|
||||
disabled_config = {
|
||||
"sync": {"enabled": False}
|
||||
}
|
||||
|
||||
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
|
||||
patch('leggend.background.scheduler.config') as mock_config_obj:
|
||||
disabled_config = {"sync": {"enabled": False}}
|
||||
|
||||
with (
|
||||
patch.object(scheduler, "scheduler") as mock_scheduler,
|
||||
patch("leggend.background.scheduler.config") as mock_config_obj,
|
||||
):
|
||||
mock_config_obj.scheduler_config = disabled_config
|
||||
mock_scheduler.running = False
|
||||
|
||||
@@ -82,11 +78,11 @@ class TestBackgroundScheduler:
|
||||
cron_config = {
|
||||
"sync": {
|
||||
"enabled": True,
|
||||
"cron": "0 6 * * 1-5" # 6 AM on weekdays
|
||||
"cron": "0 6 * * 1-5", # 6 AM on weekdays
|
||||
}
|
||||
}
|
||||
|
||||
with patch('leggend.config.config') as mock_config_obj:
|
||||
with patch("leggend.config.config") as mock_config_obj:
|
||||
mock_config_obj.scheduler_config = cron_config
|
||||
|
||||
scheduler.start()
|
||||
@@ -96,20 +92,16 @@ class TestBackgroundScheduler:
|
||||
scheduler.scheduler.add_job.assert_called_once()
|
||||
# Verify job was added with correct ID
|
||||
call_args = scheduler.scheduler.add_job.call_args
|
||||
assert call_args.kwargs['id'] == 'daily_sync'
|
||||
assert call_args.kwargs["id"] == "daily_sync"
|
||||
|
||||
def test_scheduler_start_invalid_cron(self, scheduler):
|
||||
"""Test handling of invalid cron expressions."""
|
||||
invalid_cron_config = {
|
||||
"sync": {
|
||||
"enabled": True,
|
||||
"cron": "invalid cron"
|
||||
}
|
||||
}
|
||||
|
||||
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
|
||||
patch('leggend.background.scheduler.config') as mock_config_obj:
|
||||
invalid_cron_config = {"sync": {"enabled": True, "cron": "invalid cron"}}
|
||||
|
||||
with (
|
||||
patch.object(scheduler, "scheduler") as mock_scheduler,
|
||||
patch("leggend.background.scheduler.config") as mock_config_obj,
|
||||
):
|
||||
mock_config_obj.scheduler_config = invalid_cron_config
|
||||
mock_scheduler.running = False
|
||||
|
||||
@@ -133,11 +125,7 @@ class TestBackgroundScheduler:
|
||||
scheduler.scheduler.running = True
|
||||
|
||||
# Reschedule with new config
|
||||
new_config = {
|
||||
"enabled": True,
|
||||
"hour": 6,
|
||||
"minute": 30
|
||||
}
|
||||
new_config = {"enabled": True, "hour": 6, "minute": 30}
|
||||
|
||||
scheduler.reschedule_sync(new_config)
|
||||
|
||||
@@ -202,10 +190,10 @@ class TestBackgroundScheduler:
|
||||
|
||||
def test_scheduler_job_max_instances(self, scheduler, mock_config):
|
||||
"""Test that sync jobs have max_instances=1."""
|
||||
with patch('leggend.config.config') as mock_config_obj:
|
||||
with patch("leggend.config.config") as mock_config_obj:
|
||||
mock_config_obj.scheduler_config = mock_config
|
||||
scheduler.start()
|
||||
|
||||
# Verify add_job was called with max_instances=1
|
||||
call_args = scheduler.scheduler.add_job.call_args
|
||||
assert call_args.kwargs['max_instances'] == 1
|
||||
assert call_args.kwargs["max_instances"] == 1
|
||||
|
||||
Reference in New Issue
Block a user