mirror of
https://github.com/elisiariocouto/leggen.git
synced 2025-12-13 11:22:21 +00:00
chore: Implement code review suggestions and format code.
This commit is contained in:
committed by
Elisiário Couto
parent
47164e8546
commit
de3da84dff
@@ -1,32 +0,0 @@
|
|||||||
FROM python:3.13-alpine AS builder
|
|
||||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
|
||||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
|
||||||
uv sync --frozen --no-install-project --no-editable
|
|
||||||
|
|
||||||
COPY . /app
|
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
uv sync --frozen --no-editable --no-group dev
|
|
||||||
|
|
||||||
FROM python:3.13-alpine
|
|
||||||
|
|
||||||
LABEL org.opencontainers.image.source="https://github.com/elisiariocouto/leggen"
|
|
||||||
LABEL org.opencontainers.image.authors="Elisiário Couto <elisiario@couto.io>"
|
|
||||||
LABEL org.opencontainers.image.licenses="MIT"
|
|
||||||
LABEL org.opencontainers.image.title="leggend"
|
|
||||||
LABEL org.opencontainers.image.description="Leggen API service"
|
|
||||||
LABEL org.opencontainers.image.url="https://github.com/elisiariocouto/leggen"
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
ENV PATH="/app/.venv/bin:$PATH"
|
|
||||||
|
|
||||||
COPY --from=builder --chown=app:app /app/.venv /app/.venv
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
ENTRYPOINT ["/app/.venv/bin/leggend"]
|
|
||||||
15
compose.yml
15
compose.yml
@@ -3,7 +3,6 @@ services:
|
|||||||
leggend:
|
leggend:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: Dockerfile.leggend
|
|
||||||
restart: "unless-stopped"
|
restart: "unless-stopped"
|
||||||
ports:
|
ports:
|
||||||
- "127.0.0.1:8000:8000"
|
- "127.0.0.1:8000:8000"
|
||||||
@@ -18,20 +17,6 @@ services:
|
|||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
|
|
||||||
# CLI for one-off operations (uses leggend API)
|
|
||||||
leggen:
|
|
||||||
image: elisiariocouto/leggen:latest
|
|
||||||
command: sync --wait
|
|
||||||
restart: "no"
|
|
||||||
volumes:
|
|
||||||
- "./leggen:/root/.config/leggen"
|
|
||||||
- "./db:/app"
|
|
||||||
environment:
|
|
||||||
- LEGGEND_API_URL=http://leggend:8000
|
|
||||||
depends_on:
|
|
||||||
leggend:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
nocodb:
|
nocodb:
|
||||||
image: nocodb/nocodb:latest
|
image: nocodb/nocodb:latest
|
||||||
restart: "unless-stopped"
|
restart: "unless-stopped"
|
||||||
|
|||||||
@@ -8,19 +8,20 @@ from leggen.utils.text import error
|
|||||||
|
|
||||||
class LeggendAPIClient:
|
class LeggendAPIClient:
|
||||||
"""Client for communicating with the leggend FastAPI service"""
|
"""Client for communicating with the leggend FastAPI service"""
|
||||||
|
|
||||||
def __init__(self, base_url: Optional[str] = None):
|
def __init__(self, base_url: Optional[str] = None):
|
||||||
self.base_url = base_url or os.environ.get("LEGGEND_API_URL", "http://localhost:8000")
|
self.base_url = base_url or os.environ.get(
|
||||||
|
"LEGGEND_API_URL", "http://localhost:8000"
|
||||||
|
)
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.session.headers.update({
|
self.session.headers.update(
|
||||||
"Content-Type": "application/json",
|
{"Content-Type": "application/json", "Accept": "application/json"}
|
||||||
"Accept": "application/json"
|
)
|
||||||
})
|
|
||||||
|
|
||||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
|
def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
|
||||||
"""Make HTTP request to the API"""
|
"""Make HTTP request to the API"""
|
||||||
url = urljoin(self.base_url, endpoint)
|
url = urljoin(self.base_url, endpoint)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.session.request(method, url, **kwargs)
|
response = self.session.request(method, url, **kwargs)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -53,15 +54,19 @@ class LeggendAPIClient:
|
|||||||
# Bank endpoints
|
# Bank endpoints
|
||||||
def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]:
|
def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]:
|
||||||
"""Get bank institutions for a country"""
|
"""Get bank institutions for a country"""
|
||||||
response = self._make_request("GET", "/api/v1/banks/institutions", params={"country": country})
|
response = self._make_request(
|
||||||
|
"GET", "/api/v1/banks/institutions", params={"country": country}
|
||||||
|
)
|
||||||
return response.get("data", [])
|
return response.get("data", [])
|
||||||
|
|
||||||
def connect_to_bank(self, institution_id: str, redirect_url: str = "http://localhost:8000/") -> Dict[str, Any]:
|
def connect_to_bank(
|
||||||
|
self, institution_id: str, redirect_url: str = "http://localhost:8000/"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Connect to a bank"""
|
"""Connect to a bank"""
|
||||||
response = self._make_request(
|
response = self._make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/api/v1/banks/connect",
|
"/api/v1/banks/connect",
|
||||||
json={"institution_id": institution_id, "redirect_url": redirect_url}
|
json={"institution_id": institution_id, "redirect_url": redirect_url},
|
||||||
)
|
)
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|
||||||
@@ -91,31 +96,39 @@ class LeggendAPIClient:
|
|||||||
response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances")
|
response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances")
|
||||||
return response.get("data", [])
|
return response.get("data", [])
|
||||||
|
|
||||||
def get_account_transactions(self, account_id: str, limit: int = 100, summary_only: bool = False) -> List[Dict[str, Any]]:
|
def get_account_transactions(
|
||||||
|
self, account_id: str, limit: int = 100, summary_only: bool = False
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Get account transactions"""
|
"""Get account transactions"""
|
||||||
response = self._make_request(
|
response = self._make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/accounts/{account_id}/transactions",
|
f"/api/v1/accounts/{account_id}/transactions",
|
||||||
params={"limit": limit, "summary_only": summary_only}
|
params={"limit": limit, "summary_only": summary_only},
|
||||||
)
|
)
|
||||||
return response.get("data", [])
|
return response.get("data", [])
|
||||||
|
|
||||||
# Transaction endpoints
|
# Transaction endpoints
|
||||||
def get_all_transactions(self, limit: int = 100, summary_only: bool = True, **filters) -> List[Dict[str, Any]]:
|
def get_all_transactions(
|
||||||
|
self, limit: int = 100, summary_only: bool = True, **filters
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Get all transactions with optional filters"""
|
"""Get all transactions with optional filters"""
|
||||||
params = {"limit": limit, "summary_only": summary_only}
|
params = {"limit": limit, "summary_only": summary_only}
|
||||||
params.update(filters)
|
params.update(filters)
|
||||||
|
|
||||||
response = self._make_request("GET", "/api/v1/transactions", params=params)
|
response = self._make_request("GET", "/api/v1/transactions", params=params)
|
||||||
return response.get("data", [])
|
return response.get("data", [])
|
||||||
|
|
||||||
def get_transaction_stats(self, days: int = 30, account_id: Optional[str] = None) -> Dict[str, Any]:
|
def get_transaction_stats(
|
||||||
|
self, days: int = 30, account_id: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Get transaction statistics"""
|
"""Get transaction statistics"""
|
||||||
params = {"days": days}
|
params = {"days": days}
|
||||||
if account_id:
|
if account_id:
|
||||||
params["account_id"] = account_id
|
params["account_id"] = account_id
|
||||||
|
|
||||||
response = self._make_request("GET", "/api/v1/transactions/stats", params=params)
|
response = self._make_request(
|
||||||
|
"GET", "/api/v1/transactions/stats", params=params
|
||||||
|
)
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|
||||||
# Sync endpoints
|
# Sync endpoints
|
||||||
@@ -124,21 +137,25 @@ class LeggendAPIClient:
|
|||||||
response = self._make_request("GET", "/api/v1/sync/status")
|
response = self._make_request("GET", "/api/v1/sync/status")
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|
||||||
def trigger_sync(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]:
|
def trigger_sync(
|
||||||
|
self, account_ids: Optional[List[str]] = None, force: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Trigger a sync"""
|
"""Trigger a sync"""
|
||||||
data = {"force": force}
|
data = {"force": force}
|
||||||
if account_ids:
|
if account_ids:
|
||||||
data["account_ids"] = account_ids
|
data["account_ids"] = account_ids
|
||||||
|
|
||||||
response = self._make_request("POST", "/api/v1/sync", json=data)
|
response = self._make_request("POST", "/api/v1/sync", json=data)
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|
||||||
def sync_now(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]:
|
def sync_now(
|
||||||
|
self, account_ids: Optional[List[str]] = None, force: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Run sync synchronously"""
|
"""Run sync synchronously"""
|
||||||
data = {"force": force}
|
data = {"force": force}
|
||||||
if account_ids:
|
if account_ids:
|
||||||
data["account_ids"] = account_ids
|
data["account_ids"] = account_ids
|
||||||
|
|
||||||
response = self._make_request("POST", "/api/v1/sync/now", json=data)
|
response = self._make_request("POST", "/api/v1/sync/now", json=data)
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|
||||||
@@ -147,11 +164,17 @@ class LeggendAPIClient:
|
|||||||
response = self._make_request("GET", "/api/v1/sync/scheduler")
|
response = self._make_request("GET", "/api/v1/sync/scheduler")
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|
||||||
def update_scheduler_config(self, enabled: bool = True, hour: int = 3, minute: int = 0, cron: Optional[str] = None) -> Dict[str, Any]:
|
def update_scheduler_config(
|
||||||
|
self,
|
||||||
|
enabled: bool = True,
|
||||||
|
hour: int = 3,
|
||||||
|
minute: int = 0,
|
||||||
|
cron: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Update scheduler configuration"""
|
"""Update scheduler configuration"""
|
||||||
data = {"enabled": enabled, "hour": hour, "minute": minute}
|
data = {"enabled": enabled, "hour": hour, "minute": minute}
|
||||||
if cron:
|
if cron:
|
||||||
data["cron"] = cron
|
data["cron"] = cron
|
||||||
|
|
||||||
response = self._make_request("PUT", "/api/v1/sync/scheduler", json=data)
|
response = self._make_request("PUT", "/api/v1/sync/scheduler", json=data)
|
||||||
return response.get("data", {})
|
return response.get("data", {})
|
||||||
|
|||||||
@@ -12,10 +12,12 @@ def balances(ctx: click.Context):
|
|||||||
List balances of all connected accounts
|
List balances of all connected accounts
|
||||||
"""
|
"""
|
||||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||||
|
|
||||||
# Check if leggend service is available
|
# Check if leggend service is available
|
||||||
if not api_client.health_check():
|
if not api_client.health_check():
|
||||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
click.echo(
|
||||||
|
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
accounts = api_client.get_accounts()
|
accounts = api_client.get_accounts()
|
||||||
@@ -24,11 +26,7 @@ def balances(ctx: click.Context):
|
|||||||
for account in accounts:
|
for account in accounts:
|
||||||
for balance in account.get("balances", []):
|
for balance in account.get("balances", []):
|
||||||
amount = round(float(balance["amount"]), 2)
|
amount = round(float(balance["amount"]), 2)
|
||||||
symbol = (
|
symbol = "€" if balance["currency"] == "EUR" else f" {balance['currency']}"
|
||||||
"€"
|
|
||||||
if balance["currency"] == "EUR"
|
|
||||||
else f" {balance['currency']}"
|
|
||||||
)
|
|
||||||
amount_str = f"{amount}{symbol}"
|
amount_str = f"{amount}{symbol}"
|
||||||
date = (
|
date = (
|
||||||
datefmt(balance.get("last_change_date"))
|
datefmt(balance.get("last_change_date"))
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import click
|
|
||||||
|
|
||||||
from leggen.main import cli
|
|
||||||
|
|
||||||
cmd_folder = os.path.abspath(os.path.dirname(__file__))
|
|
||||||
|
|
||||||
|
|
||||||
class BankGroup(click.Group):
|
|
||||||
def list_commands(self, ctx):
|
|
||||||
rv = []
|
|
||||||
for filename in os.listdir(cmd_folder):
|
|
||||||
if filename.endswith(".py") and not filename.startswith("__init__"):
|
|
||||||
if filename == "list_banks.py":
|
|
||||||
rv.append("list")
|
|
||||||
else:
|
|
||||||
rv.append(filename[:-3])
|
|
||||||
rv.sort()
|
|
||||||
return rv
|
|
||||||
|
|
||||||
def get_command(self, ctx, name):
|
|
||||||
try:
|
|
||||||
if name == "list":
|
|
||||||
name = "list_banks"
|
|
||||||
mod = __import__(f"leggen.commands.bank.{name}", None, None, [name])
|
|
||||||
except ImportError:
|
|
||||||
return
|
|
||||||
return getattr(mod, name)
|
|
||||||
|
|
||||||
|
|
||||||
@cli.group(cls=BankGroup)
|
|
||||||
@click.pass_context
|
|
||||||
def bank(ctx):
|
|
||||||
"""Manage banks connections"""
|
|
||||||
return
|
|
||||||
@@ -13,30 +13,32 @@ def add(ctx):
|
|||||||
Connect to a bank
|
Connect to a bank
|
||||||
"""
|
"""
|
||||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||||
|
|
||||||
# Check if leggend service is available
|
# Check if leggend service is available
|
||||||
if not api_client.health_check():
|
if not api_client.health_check():
|
||||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
click.echo(
|
||||||
|
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get supported countries
|
# Get supported countries
|
||||||
countries = api_client.get_supported_countries()
|
countries = api_client.get_supported_countries()
|
||||||
country_codes = [c["code"] for c in countries]
|
country_codes = [c["code"] for c in countries]
|
||||||
|
|
||||||
country = click.prompt(
|
country = click.prompt(
|
||||||
"Bank Country",
|
"Bank Country",
|
||||||
type=click.Choice(country_codes, case_sensitive=True),
|
type=click.Choice(country_codes, case_sensitive=True),
|
||||||
default="PT",
|
default="PT",
|
||||||
)
|
)
|
||||||
|
|
||||||
info(f"Getting bank list for country: {country}")
|
info(f"Getting bank list for country: {country}")
|
||||||
banks = api_client.get_institutions(country)
|
banks = api_client.get_institutions(country)
|
||||||
|
|
||||||
if not banks:
|
if not banks:
|
||||||
warning(f"No banks available for country {country}")
|
warning(f"No banks available for country {country}")
|
||||||
return
|
return
|
||||||
|
|
||||||
filtered_banks = [
|
filtered_banks = [
|
||||||
{
|
{
|
||||||
"id": bank["id"],
|
"id": bank["id"],
|
||||||
@@ -46,14 +48,14 @@ def add(ctx):
|
|||||||
for bank in banks
|
for bank in banks
|
||||||
]
|
]
|
||||||
print_table(filtered_banks)
|
print_table(filtered_banks)
|
||||||
|
|
||||||
allowed_ids = [str(bank["id"]) for bank in banks]
|
allowed_ids = [str(bank["id"]) for bank in banks]
|
||||||
bank_id = click.prompt("Bank ID", type=click.Choice(allowed_ids))
|
bank_id = click.prompt("Bank ID", type=click.Choice(allowed_ids))
|
||||||
|
|
||||||
# Show bank details
|
# Show bank details
|
||||||
selected_bank = next(bank for bank in banks if bank["id"] == bank_id)
|
selected_bank = next(bank for bank in banks if bank["id"] == bank_id)
|
||||||
info(f"Selected bank: {selected_bank['name']}")
|
info(f"Selected bank: {selected_bank['name']}")
|
||||||
|
|
||||||
click.confirm("Do you agree to connect to this bank?", abort=True)
|
click.confirm("Do you agree to connect to this bank?", abort=True)
|
||||||
|
|
||||||
info(f"Connecting to bank with ID: {bank_id}")
|
info(f"Connecting to bank with ID: {bank_id}")
|
||||||
@@ -65,11 +67,15 @@ def add(ctx):
|
|||||||
save_file(f"req_{result['id']}.json", result)
|
save_file(f"req_{result['id']}.json", result)
|
||||||
|
|
||||||
success("Bank connection request created successfully!")
|
success("Bank connection request created successfully!")
|
||||||
warning(f"Please open the following URL in your browser to complete the authorization:")
|
warning(
|
||||||
|
f"Please open the following URL in your browser to complete the authorization:"
|
||||||
|
)
|
||||||
click.echo(f"\n{result['link']}\n")
|
click.echo(f"\n{result['link']}\n")
|
||||||
|
|
||||||
info(f"Requisition ID: {result['id']}")
|
info(f"Requisition ID: {result['id']}")
|
||||||
info("After completing the authorization, you can check the connection status with 'leggen status'")
|
info(
|
||||||
|
"After completing the authorization, you can check the connection status with 'leggen status'"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Error: Failed to connect to bank: {str(e)}")
|
click.echo(f"Error: Failed to connect to bank: {str(e)}")
|
||||||
|
|||||||
@@ -12,10 +12,12 @@ def status(ctx: click.Context):
|
|||||||
List all connected banks and their status
|
List all connected banks and their status
|
||||||
"""
|
"""
|
||||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||||
|
|
||||||
# Check if leggend service is available
|
# Check if leggend service is available
|
||||||
if not api_client.health_check():
|
if not api_client.health_check():
|
||||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
click.echo(
|
||||||
|
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get bank connection status
|
# Get bank connection status
|
||||||
|
|||||||
@@ -6,15 +6,15 @@ from leggen.utils.text import error, info, success
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option('--wait', is_flag=True, help='Wait for sync to complete (synchronous)')
|
@click.option("--wait", is_flag=True, help="Wait for sync to complete (synchronous)")
|
||||||
@click.option('--force', is_flag=True, help='Force sync even if already running')
|
@click.option("--force", is_flag=True, help="Force sync even if already running")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def sync(ctx: click.Context, wait: bool, force: bool):
|
def sync(ctx: click.Context, wait: bool, force: bool):
|
||||||
"""
|
"""
|
||||||
Sync all transactions with database
|
Sync all transactions with database
|
||||||
"""
|
"""
|
||||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||||
|
|
||||||
# Check if leggend service is available
|
# Check if leggend service is available
|
||||||
if not api_client.health_check():
|
if not api_client.health_check():
|
||||||
error("Cannot connect to leggend service. Please ensure it's running.")
|
error("Cannot connect to leggend service. Please ensure it's running.")
|
||||||
@@ -25,35 +25,37 @@ def sync(ctx: click.Context, wait: bool, force: bool):
|
|||||||
# Run sync synchronously and wait for completion
|
# Run sync synchronously and wait for completion
|
||||||
info("Starting synchronous sync...")
|
info("Starting synchronous sync...")
|
||||||
result = api_client.sync_now(force=force)
|
result = api_client.sync_now(force=force)
|
||||||
|
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
success(f"Sync completed successfully!")
|
success(f"Sync completed successfully!")
|
||||||
info(f"Accounts processed: {result.get('accounts_processed', 0)}")
|
info(f"Accounts processed: {result.get('accounts_processed', 0)}")
|
||||||
info(f"Transactions added: {result.get('transactions_added', 0)}")
|
info(f"Transactions added: {result.get('transactions_added', 0)}")
|
||||||
info(f"Balances updated: {result.get('balances_updated', 0)}")
|
info(f"Balances updated: {result.get('balances_updated', 0)}")
|
||||||
if result.get('duration_seconds'):
|
if result.get("duration_seconds"):
|
||||||
info(f"Duration: {result['duration_seconds']:.2f} seconds")
|
info(f"Duration: {result['duration_seconds']:.2f} seconds")
|
||||||
|
|
||||||
if result.get('errors'):
|
if result.get("errors"):
|
||||||
error(f"Errors encountered: {len(result['errors'])}")
|
error(f"Errors encountered: {len(result['errors'])}")
|
||||||
for err in result['errors']:
|
for err in result["errors"]:
|
||||||
error(f" - {err}")
|
error(f" - {err}")
|
||||||
else:
|
else:
|
||||||
error("Sync failed")
|
error("Sync failed")
|
||||||
if result.get('errors'):
|
if result.get("errors"):
|
||||||
for err in result['errors']:
|
for err in result["errors"]:
|
||||||
error(f" - {err}")
|
error(f" - {err}")
|
||||||
else:
|
else:
|
||||||
# Trigger async sync
|
# Trigger async sync
|
||||||
info("Starting background sync...")
|
info("Starting background sync...")
|
||||||
result = api_client.trigger_sync(force=force)
|
result = api_client.trigger_sync(force=force)
|
||||||
|
|
||||||
if result.get("sync_started"):
|
if result.get("sync_started"):
|
||||||
success("Sync started successfully in the background")
|
success("Sync started successfully in the background")
|
||||||
info("Use 'leggen sync --wait' to run synchronously or check status with API")
|
info(
|
||||||
|
"Use 'leggen sync --wait' to run synchronously or check status with API"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
error("Failed to start sync")
|
error("Failed to start sync")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"Sync failed: {str(e)}")
|
error(f"Sync failed: {str(e)}")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ from leggen.utils.text import datefmt, info, print_table
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-a", "--account", type=str, help="Account ID")
|
@click.option("-a", "--account", type=str, help="Account ID")
|
||||||
@click.option("-l", "--limit", type=int, default=50, help="Number of transactions to show")
|
@click.option(
|
||||||
|
"-l", "--limit", type=int, default=50, help="Number of transactions to show"
|
||||||
|
)
|
||||||
@click.option("--full", is_flag=True, help="Show full transaction details")
|
@click.option("--full", is_flag=True, help="Show full transaction details")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
||||||
@@ -19,10 +21,12 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
|||||||
If the --account option is used, it will only list transactions for that account.
|
If the --account option is used, it will only list transactions for that account.
|
||||||
"""
|
"""
|
||||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||||
|
|
||||||
# Check if leggend service is available
|
# Check if leggend service is available
|
||||||
if not api_client.health_check():
|
if not api_client.health_check():
|
||||||
click.echo("Error: Cannot connect to leggend service. Please ensure it's running.")
|
click.echo(
|
||||||
|
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -32,16 +36,14 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
|||||||
transactions_data = api_client.get_account_transactions(
|
transactions_data = api_client.get_account_transactions(
|
||||||
account, limit=limit, summary_only=not full
|
account, limit=limit, summary_only=not full
|
||||||
)
|
)
|
||||||
|
|
||||||
info(f"Bank: {account_details['institution_id']}")
|
info(f"Bank: {account_details['institution_id']}")
|
||||||
info(f"IBAN: {account_details.get('iban', 'N/A')}")
|
info(f"IBAN: {account_details.get('iban', 'N/A')}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Get all transactions
|
# Get all transactions
|
||||||
transactions_data = api_client.get_all_transactions(
|
transactions_data = api_client.get_all_transactions(
|
||||||
limit=limit,
|
limit=limit, summary_only=not full, account_id=account
|
||||||
summary_only=not full,
|
|
||||||
account_id=account
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format transactions for display
|
# Format transactions for display
|
||||||
@@ -49,24 +51,32 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
|||||||
# Full transaction details
|
# Full transaction details
|
||||||
formatted_transactions = []
|
formatted_transactions = []
|
||||||
for txn in transactions_data:
|
for txn in transactions_data:
|
||||||
formatted_transactions.append({
|
formatted_transactions.append(
|
||||||
"ID": txn["internal_transaction_id"][:12] + "...",
|
{
|
||||||
"Date": datefmt(txn["transaction_date"]),
|
"ID": txn["internal_transaction_id"][:12] + "...",
|
||||||
"Description": txn["description"][:50] + "..." if len(txn["description"]) > 50 else txn["description"],
|
"Date": datefmt(txn["transaction_date"]),
|
||||||
"Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}",
|
"Description": txn["description"][:50] + "..."
|
||||||
"Status": txn["transaction_status"].upper(),
|
if len(txn["description"]) > 50
|
||||||
"Account": txn["account_id"][:8] + "...",
|
else txn["description"],
|
||||||
})
|
"Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}",
|
||||||
|
"Status": txn["transaction_status"].upper(),
|
||||||
|
"Account": txn["account_id"][:8] + "...",
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Summary view
|
# Summary view
|
||||||
formatted_transactions = []
|
formatted_transactions = []
|
||||||
for txn in transactions_data:
|
for txn in transactions_data:
|
||||||
formatted_transactions.append({
|
formatted_transactions.append(
|
||||||
"Date": datefmt(txn["date"]),
|
{
|
||||||
"Description": txn["description"][:60] + "..." if len(txn["description"]) > 60 else txn["description"],
|
"Date": datefmt(txn["date"]),
|
||||||
"Amount": f"{txn['amount']:.2f} {txn['currency']}",
|
"Description": txn["description"][:60] + "..."
|
||||||
"Status": txn["status"].upper(),
|
if len(txn["description"]) > 60
|
||||||
})
|
else txn["description"],
|
||||||
|
"Amount": f"{txn['amount']:.2f} {txn['currency']}",
|
||||||
|
"Status": txn["status"].upper(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if formatted_transactions:
|
if formatted_transactions:
|
||||||
print_table(formatted_transactions)
|
print_table(formatted_transactions)
|
||||||
|
|||||||
@@ -90,10 +90,10 @@ class Group(click.Group):
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--api-url",
|
"--api-url",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default="http://localhost:8000",
|
||||||
envvar="LEGGEND_API_URL",
|
envvar="LEGGEND_API_URL",
|
||||||
show_envvar=True,
|
show_envvar=True,
|
||||||
help="URL of the leggend API service (default: http://localhost:8000)",
|
help="URL of the leggend API service",
|
||||||
)
|
)
|
||||||
@click.group(
|
@click.group(
|
||||||
cls=Group,
|
cls=Group,
|
||||||
@@ -113,7 +113,7 @@ def cli(ctx: click.Context, api_url: str):
|
|||||||
# Store API URL in context for commands to use
|
# Store API URL in context for commands to use
|
||||||
if api_url:
|
if api_url:
|
||||||
ctx.obj["api_url"] = api_url
|
ctx.obj["api_url"] = api_url
|
||||||
|
|
||||||
# For backwards compatibility, still support direct GoCardless calls
|
# For backwards compatibility, still support direct GoCardless calls
|
||||||
# This will be used as fallback if leggend service is not available
|
# This will be used as fallback if leggend service is not available
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,19 +6,19 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class AccountBalance(BaseModel):
|
class AccountBalance(BaseModel):
|
||||||
"""Account balance model"""
|
"""Account balance model"""
|
||||||
|
|
||||||
amount: float
|
amount: float
|
||||||
currency: str
|
currency: str
|
||||||
balance_type: str
|
balance_type: str
|
||||||
last_change_date: Optional[datetime] = None
|
last_change_date: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||||
datetime: lambda v: v.isoformat() if v else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AccountDetails(BaseModel):
|
class AccountDetails(BaseModel):
|
||||||
"""Account details model"""
|
"""Account details model"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
institution_id: str
|
institution_id: str
|
||||||
status: str
|
status: str
|
||||||
@@ -28,15 +28,14 @@ class AccountDetails(BaseModel):
|
|||||||
created: datetime
|
created: datetime
|
||||||
last_accessed: Optional[datetime] = None
|
last_accessed: Optional[datetime] = None
|
||||||
balances: List[AccountBalance] = []
|
balances: List[AccountBalance] = []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||||
datetime: lambda v: v.isoformat() if v else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Transaction(BaseModel):
|
class Transaction(BaseModel):
|
||||||
"""Transaction model"""
|
"""Transaction model"""
|
||||||
|
|
||||||
internal_transaction_id: str
|
internal_transaction_id: str
|
||||||
institution_id: str
|
institution_id: str
|
||||||
iban: Optional[str] = None
|
iban: Optional[str] = None
|
||||||
@@ -47,15 +46,14 @@ class Transaction(BaseModel):
|
|||||||
transaction_currency: str
|
transaction_currency: str
|
||||||
transaction_status: str # "booked" or "pending"
|
transaction_status: str # "booked" or "pending"
|
||||||
raw_transaction: Dict[str, Any]
|
raw_transaction: Dict[str, Any]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
datetime: lambda v: v.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TransactionSummary(BaseModel):
|
class TransactionSummary(BaseModel):
|
||||||
"""Transaction summary for lists"""
|
"""Transaction summary for lists"""
|
||||||
|
|
||||||
internal_transaction_id: str
|
internal_transaction_id: str
|
||||||
date: datetime
|
date: datetime
|
||||||
description: str
|
description: str
|
||||||
@@ -63,8 +61,6 @@ class TransactionSummary(BaseModel):
|
|||||||
currency: str
|
currency: str
|
||||||
status: str
|
status: str
|
||||||
account_id: str
|
account_id: str
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
datetime: lambda v: v.isoformat()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class BankInstitution(BaseModel):
|
class BankInstitution(BaseModel):
|
||||||
"""Bank institution model"""
|
"""Bank institution model"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
bic: Optional[str] = None
|
bic: Optional[str] = None
|
||||||
@@ -16,12 +17,14 @@ class BankInstitution(BaseModel):
|
|||||||
|
|
||||||
class BankConnectionRequest(BaseModel):
|
class BankConnectionRequest(BaseModel):
|
||||||
"""Request to connect to a bank"""
|
"""Request to connect to a bank"""
|
||||||
|
|
||||||
institution_id: str
|
institution_id: str
|
||||||
redirect_url: Optional[str] = "http://localhost:8000/"
|
redirect_url: Optional[str] = "http://localhost:8000/"
|
||||||
|
|
||||||
|
|
||||||
class BankRequisition(BaseModel):
|
class BankRequisition(BaseModel):
|
||||||
"""Bank requisition/connection model"""
|
"""Bank requisition/connection model"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
institution_id: str
|
institution_id: str
|
||||||
status: str
|
status: str
|
||||||
@@ -29,15 +32,14 @@ class BankRequisition(BaseModel):
|
|||||||
created: datetime
|
created: datetime
|
||||||
link: str
|
link: str
|
||||||
accounts: List[str] = []
|
accounts: List[str] = []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
datetime: lambda v: v.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class BankConnectionStatus(BaseModel):
|
class BankConnectionStatus(BaseModel):
|
||||||
"""Bank connection status response"""
|
"""Bank connection status response"""
|
||||||
|
|
||||||
bank_id: str
|
bank_id: str
|
||||||
bank_name: str
|
bank_name: str
|
||||||
status: str
|
status: str
|
||||||
@@ -45,8 +47,6 @@ class BankConnectionStatus(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
requisition_id: str
|
requisition_id: str
|
||||||
accounts_count: int
|
accounts_count: int
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
datetime: lambda v: v.isoformat()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class APIResponse(BaseModel):
|
class APIResponse(BaseModel):
|
||||||
"""Base API response model"""
|
"""Base API response model"""
|
||||||
|
|
||||||
success: bool = True
|
success: bool = True
|
||||||
message: Optional[str] = None
|
message: Optional[str] = None
|
||||||
data: Optional[Any] = None
|
data: Optional[Any] = None
|
||||||
@@ -13,6 +14,7 @@ class APIResponse(BaseModel):
|
|||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
class ErrorResponse(BaseModel):
|
||||||
"""Error response model"""
|
"""Error response model"""
|
||||||
|
|
||||||
success: bool = False
|
success: bool = False
|
||||||
message: str
|
message: str
|
||||||
error_code: Optional[str] = None
|
error_code: Optional[str] = None
|
||||||
@@ -21,7 +23,8 @@ class ErrorResponse(BaseModel):
|
|||||||
|
|
||||||
class PaginatedResponse(BaseModel):
|
class PaginatedResponse(BaseModel):
|
||||||
"""Paginated response model"""
|
"""Paginated response model"""
|
||||||
|
|
||||||
success: bool = True
|
success: bool = True
|
||||||
data: list
|
data: list
|
||||||
pagination: Dict[str, Any]
|
pagination: Dict[str, Any]
|
||||||
message: Optional[str] = None
|
message: Optional[str] = None
|
||||||
|
|||||||
@@ -5,12 +5,14 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class DiscordConfig(BaseModel):
|
class DiscordConfig(BaseModel):
|
||||||
"""Discord notification configuration"""
|
"""Discord notification configuration"""
|
||||||
|
|
||||||
webhook: str
|
webhook: str
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
class TelegramConfig(BaseModel):
|
class TelegramConfig(BaseModel):
|
||||||
"""Telegram notification configuration"""
|
"""Telegram notification configuration"""
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
chat_id: int
|
chat_id: int
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
@@ -18,6 +20,7 @@ class TelegramConfig(BaseModel):
|
|||||||
|
|
||||||
class NotificationFilters(BaseModel):
|
class NotificationFilters(BaseModel):
|
||||||
"""Notification filters configuration"""
|
"""Notification filters configuration"""
|
||||||
|
|
||||||
case_insensitive: Dict[str, str] = {}
|
case_insensitive: Dict[str, str] = {}
|
||||||
case_sensitive: Optional[Dict[str, str]] = None
|
case_sensitive: Optional[Dict[str, str]] = None
|
||||||
amount_threshold: Optional[float] = None
|
amount_threshold: Optional[float] = None
|
||||||
@@ -26,6 +29,7 @@ class NotificationFilters(BaseModel):
|
|||||||
|
|
||||||
class NotificationSettings(BaseModel):
|
class NotificationSettings(BaseModel):
|
||||||
"""Complete notification settings"""
|
"""Complete notification settings"""
|
||||||
|
|
||||||
discord: Optional[DiscordConfig] = None
|
discord: Optional[DiscordConfig] = None
|
||||||
telegram: Optional[TelegramConfig] = None
|
telegram: Optional[TelegramConfig] = None
|
||||||
filters: NotificationFilters = NotificationFilters()
|
filters: NotificationFilters = NotificationFilters()
|
||||||
@@ -33,15 +37,17 @@ class NotificationSettings(BaseModel):
|
|||||||
|
|
||||||
class NotificationTest(BaseModel):
|
class NotificationTest(BaseModel):
|
||||||
"""Test notification request"""
|
"""Test notification request"""
|
||||||
|
|
||||||
service: str # "discord" or "telegram"
|
service: str # "discord" or "telegram"
|
||||||
message: str = "Test notification from Leggen"
|
message: str = "Test notification from Leggen"
|
||||||
|
|
||||||
|
|
||||||
class NotificationHistory(BaseModel):
|
class NotificationHistory(BaseModel):
|
||||||
"""Notification history entry"""
|
"""Notification history entry"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
service: str
|
service: str
|
||||||
message: str
|
message: str
|
||||||
status: str # "sent", "failed"
|
status: str # "sent", "failed"
|
||||||
sent_at: str
|
sent_at: str
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class SyncRequest(BaseModel):
|
class SyncRequest(BaseModel):
|
||||||
"""Request to trigger a sync"""
|
"""Request to trigger a sync"""
|
||||||
|
|
||||||
account_ids: Optional[list[str]] = None # If None, sync all accounts
|
account_ids: Optional[list[str]] = None # If None, sync all accounts
|
||||||
force: bool = False # Force sync even if recently synced
|
force: bool = False # Force sync even if recently synced
|
||||||
|
|
||||||
|
|
||||||
class SyncStatus(BaseModel):
|
class SyncStatus(BaseModel):
|
||||||
"""Sync operation status"""
|
"""Sync operation status"""
|
||||||
|
|
||||||
is_running: bool
|
is_running: bool
|
||||||
last_sync: Optional[datetime] = None
|
last_sync: Optional[datetime] = None
|
||||||
next_sync: Optional[datetime] = None
|
next_sync: Optional[datetime] = None
|
||||||
@@ -19,15 +21,14 @@ class SyncStatus(BaseModel):
|
|||||||
total_accounts: int = 0
|
total_accounts: int = 0
|
||||||
transactions_added: int = 0
|
transactions_added: int = 0
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||||
datetime: lambda v: v.isoformat() if v else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SyncResult(BaseModel):
|
class SyncResult(BaseModel):
|
||||||
"""Result of a sync operation"""
|
"""Result of a sync operation"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
accounts_processed: int
|
accounts_processed: int
|
||||||
transactions_added: int
|
transactions_added: int
|
||||||
@@ -37,19 +38,18 @@ class SyncResult(BaseModel):
|
|||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
completed_at: datetime
|
completed_at: datetime
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_encoders = {
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
datetime: lambda v: v.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerConfig(BaseModel):
|
class SchedulerConfig(BaseModel):
|
||||||
"""Scheduler configuration model"""
|
"""Scheduler configuration model"""
|
||||||
|
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
hour: Optional[int] = 3
|
hour: Optional[int] = 3
|
||||||
minute: Optional[int] = 0
|
minute: Optional[int] = 0
|
||||||
cron: Optional[str] = None # Custom cron expression
|
cron: Optional[str] = None # Custom cron expression
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "forbid"
|
extra = "forbid"
|
||||||
|
|||||||
@@ -3,7 +3,12 @@ from fastapi import APIRouter, HTTPException, Query
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from leggend.api.models.common import APIResponse
|
from leggend.api.models.common import APIResponse
|
||||||
from leggend.api.models.accounts import AccountDetails, AccountBalance, Transaction, TransactionSummary
|
from leggend.api.models.accounts import (
|
||||||
|
AccountDetails,
|
||||||
|
AccountBalance,
|
||||||
|
Transaction,
|
||||||
|
TransactionSummary,
|
||||||
|
)
|
||||||
from leggend.services.gocardless_service import GoCardlessService
|
from leggend.services.gocardless_service import GoCardlessService
|
||||||
from leggend.services.database_service import DatabaseService
|
from leggend.services.database_service import DatabaseService
|
||||||
|
|
||||||
@@ -17,50 +22,56 @@ async def get_all_accounts() -> APIResponse:
|
|||||||
"""Get all connected accounts"""
|
"""Get all connected accounts"""
|
||||||
try:
|
try:
|
||||||
requisitions_data = await gocardless_service.get_requisitions()
|
requisitions_data = await gocardless_service.get_requisitions()
|
||||||
|
|
||||||
all_accounts = set()
|
all_accounts = set()
|
||||||
for req in requisitions_data.get("results", []):
|
for req in requisitions_data.get("results", []):
|
||||||
all_accounts.update(req.get("accounts", []))
|
all_accounts.update(req.get("accounts", []))
|
||||||
|
|
||||||
accounts = []
|
accounts = []
|
||||||
for account_id in all_accounts:
|
for account_id in all_accounts:
|
||||||
try:
|
try:
|
||||||
account_details = await gocardless_service.get_account_details(account_id)
|
account_details = await gocardless_service.get_account_details(
|
||||||
balances_data = await gocardless_service.get_account_balances(account_id)
|
account_id
|
||||||
|
)
|
||||||
|
balances_data = await gocardless_service.get_account_balances(
|
||||||
|
account_id
|
||||||
|
)
|
||||||
|
|
||||||
# Process balances
|
# Process balances
|
||||||
balances = []
|
balances = []
|
||||||
for balance in balances_data.get("balances", []):
|
for balance in balances_data.get("balances", []):
|
||||||
balance_amount = balance["balanceAmount"]
|
balance_amount = balance["balanceAmount"]
|
||||||
balances.append(AccountBalance(
|
balances.append(
|
||||||
amount=float(balance_amount["amount"]),
|
AccountBalance(
|
||||||
currency=balance_amount["currency"],
|
amount=float(balance_amount["amount"]),
|
||||||
balance_type=balance["balanceType"],
|
currency=balance_amount["currency"],
|
||||||
last_change_date=balance.get("lastChangeDateTime")
|
balance_type=balance["balanceType"],
|
||||||
))
|
last_change_date=balance.get("lastChangeDateTime"),
|
||||||
|
)
|
||||||
accounts.append(AccountDetails(
|
)
|
||||||
id=account_details["id"],
|
|
||||||
institution_id=account_details["institution_id"],
|
accounts.append(
|
||||||
status=account_details["status"],
|
AccountDetails(
|
||||||
iban=account_details.get("iban"),
|
id=account_details["id"],
|
||||||
name=account_details.get("name"),
|
institution_id=account_details["institution_id"],
|
||||||
currency=account_details.get("currency"),
|
status=account_details["status"],
|
||||||
created=account_details["created"],
|
iban=account_details.get("iban"),
|
||||||
last_accessed=account_details.get("last_accessed"),
|
name=account_details.get("name"),
|
||||||
balances=balances
|
currency=account_details.get("currency"),
|
||||||
))
|
created=account_details["created"],
|
||||||
|
last_accessed=account_details.get("last_accessed"),
|
||||||
|
balances=balances,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get details for account {account_id}: {e}")
|
logger.error(f"Failed to get details for account {account_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True, data=accounts, message=f"Retrieved {len(accounts)} accounts"
|
||||||
data=accounts,
|
|
||||||
message=f"Retrieved {len(accounts)} accounts"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get accounts: {e}")
|
logger.error(f"Failed to get accounts: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get accounts: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to get accounts: {str(e)}")
|
||||||
@@ -72,18 +83,20 @@ async def get_account_details(account_id: str) -> APIResponse:
|
|||||||
try:
|
try:
|
||||||
account_details = await gocardless_service.get_account_details(account_id)
|
account_details = await gocardless_service.get_account_details(account_id)
|
||||||
balances_data = await gocardless_service.get_account_balances(account_id)
|
balances_data = await gocardless_service.get_account_balances(account_id)
|
||||||
|
|
||||||
# Process balances
|
# Process balances
|
||||||
balances = []
|
balances = []
|
||||||
for balance in balances_data.get("balances", []):
|
for balance in balances_data.get("balances", []):
|
||||||
balance_amount = balance["balanceAmount"]
|
balance_amount = balance["balanceAmount"]
|
||||||
balances.append(AccountBalance(
|
balances.append(
|
||||||
amount=float(balance_amount["amount"]),
|
AccountBalance(
|
||||||
currency=balance_amount["currency"],
|
amount=float(balance_amount["amount"]),
|
||||||
balance_type=balance["balanceType"],
|
currency=balance_amount["currency"],
|
||||||
last_change_date=balance.get("lastChangeDateTime")
|
balance_type=balance["balanceType"],
|
||||||
))
|
last_change_date=balance.get("lastChangeDateTime"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
account = AccountDetails(
|
account = AccountDetails(
|
||||||
id=account_details["id"],
|
id=account_details["id"],
|
||||||
institution_id=account_details["institution_id"],
|
institution_id=account_details["institution_id"],
|
||||||
@@ -93,15 +106,15 @@ async def get_account_details(account_id: str) -> APIResponse:
|
|||||||
currency=account_details.get("currency"),
|
currency=account_details.get("currency"),
|
||||||
created=account_details["created"],
|
created=account_details["created"],
|
||||||
last_accessed=account_details.get("last_accessed"),
|
last_accessed=account_details.get("last_accessed"),
|
||||||
balances=balances
|
balances=balances,
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=account,
|
data=account,
|
||||||
message=f"Account details retrieved for {account_id}"
|
message=f"Account details retrieved for {account_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get account details for {account_id}: {e}")
|
logger.error(f"Failed to get account details for {account_id}: {e}")
|
||||||
raise HTTPException(status_code=404, detail=f"Account not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Account not found: {str(e)}")
|
||||||
@@ -112,23 +125,25 @@ async def get_account_balances(account_id: str) -> APIResponse:
|
|||||||
"""Get balances for a specific account"""
|
"""Get balances for a specific account"""
|
||||||
try:
|
try:
|
||||||
balances_data = await gocardless_service.get_account_balances(account_id)
|
balances_data = await gocardless_service.get_account_balances(account_id)
|
||||||
|
|
||||||
balances = []
|
balances = []
|
||||||
for balance in balances_data.get("balances", []):
|
for balance in balances_data.get("balances", []):
|
||||||
balance_amount = balance["balanceAmount"]
|
balance_amount = balance["balanceAmount"]
|
||||||
balances.append(AccountBalance(
|
balances.append(
|
||||||
amount=float(balance_amount["amount"]),
|
AccountBalance(
|
||||||
currency=balance_amount["currency"],
|
amount=float(balance_amount["amount"]),
|
||||||
balance_type=balance["balanceType"],
|
currency=balance_amount["currency"],
|
||||||
last_change_date=balance.get("lastChangeDateTime")
|
balance_type=balance["balanceType"],
|
||||||
))
|
last_change_date=balance.get("lastChangeDateTime"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=balances,
|
data=balances,
|
||||||
message=f"Retrieved {len(balances)} balances for account {account_id}"
|
message=f"Retrieved {len(balances)} balances for account {account_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get balances for account {account_id}: {e}")
|
logger.error(f"Failed to get balances for account {account_id}: {e}")
|
||||||
raise HTTPException(status_code=404, detail=f"Failed to get balances: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Failed to get balances: {str(e)}")
|
||||||
@@ -139,22 +154,26 @@ async def get_account_transactions(
|
|||||||
account_id: str,
|
account_id: str,
|
||||||
limit: Optional[int] = Query(default=100, le=500),
|
limit: Optional[int] = Query(default=100, le=500),
|
||||||
offset: Optional[int] = Query(default=0, ge=0),
|
offset: Optional[int] = Query(default=0, ge=0),
|
||||||
summary_only: bool = Query(default=False, description="Return transaction summaries only")
|
summary_only: bool = Query(
|
||||||
|
default=False, description="Return transaction summaries only"
|
||||||
|
),
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""Get transactions for a specific account"""
|
"""Get transactions for a specific account"""
|
||||||
try:
|
try:
|
||||||
account_details = await gocardless_service.get_account_details(account_id)
|
account_details = await gocardless_service.get_account_details(account_id)
|
||||||
transactions_data = await gocardless_service.get_account_transactions(account_id)
|
transactions_data = await gocardless_service.get_account_transactions(
|
||||||
|
account_id
|
||||||
|
)
|
||||||
|
|
||||||
# Process transactions
|
# Process transactions
|
||||||
processed_transactions = database_service.process_transactions(
|
processed_transactions = database_service.process_transactions(
|
||||||
account_id, account_details, transactions_data
|
account_id, account_details, transactions_data
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
total_transactions = len(processed_transactions)
|
total_transactions = len(processed_transactions)
|
||||||
paginated_transactions = processed_transactions[offset:offset + limit]
|
paginated_transactions = processed_transactions[offset : offset + limit]
|
||||||
|
|
||||||
if summary_only:
|
if summary_only:
|
||||||
# Return simplified transaction summaries
|
# Return simplified transaction summaries
|
||||||
summaries = [
|
summaries = [
|
||||||
@@ -165,7 +184,7 @@ async def get_account_transactions(
|
|||||||
amount=txn["transactionValue"],
|
amount=txn["transactionValue"],
|
||||||
currency=txn["transactionCurrency"],
|
currency=txn["transactionCurrency"],
|
||||||
status=txn["transactionStatus"],
|
status=txn["transactionStatus"],
|
||||||
account_id=txn["accountId"]
|
account_id=txn["accountId"],
|
||||||
)
|
)
|
||||||
for txn in paginated_transactions
|
for txn in paginated_transactions
|
||||||
]
|
]
|
||||||
@@ -183,18 +202,20 @@ async def get_account_transactions(
|
|||||||
transaction_value=txn["transactionValue"],
|
transaction_value=txn["transactionValue"],
|
||||||
transaction_currency=txn["transactionCurrency"],
|
transaction_currency=txn["transactionCurrency"],
|
||||||
transaction_status=txn["transactionStatus"],
|
transaction_status=txn["transactionStatus"],
|
||||||
raw_transaction=txn["rawTransaction"]
|
raw_transaction=txn["rawTransaction"],
|
||||||
)
|
)
|
||||||
for txn in paginated_transactions
|
for txn in paginated_transactions
|
||||||
]
|
]
|
||||||
data = transactions
|
data = transactions
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=data,
|
data=data,
|
||||||
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})"
|
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get transactions for account {account_id}: {e}")
|
logger.error(f"Failed to get transactions for account {account_id}: {e}")
|
||||||
raise HTTPException(status_code=404, detail=f"Failed to get transactions: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Failed to get transactions: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ from loguru import logger
|
|||||||
|
|
||||||
from leggend.api.models.common import APIResponse, ErrorResponse
|
from leggend.api.models.common import APIResponse, ErrorResponse
|
||||||
from leggend.api.models.banks import (
|
from leggend.api.models.banks import (
|
||||||
BankInstitution,
|
BankInstitution,
|
||||||
BankConnectionRequest,
|
BankConnectionRequest,
|
||||||
BankRequisition,
|
BankRequisition,
|
||||||
BankConnectionStatus
|
BankConnectionStatus,
|
||||||
)
|
)
|
||||||
from leggend.services.gocardless_service import GoCardlessService
|
from leggend.services.gocardless_service import GoCardlessService
|
||||||
from leggend.utils.gocardless import REQUISITION_STATUS
|
from leggend.utils.gocardless import REQUISITION_STATUS
|
||||||
@@ -18,12 +18,12 @@ gocardless_service = GoCardlessService()
|
|||||||
|
|
||||||
@router.get("/banks/institutions", response_model=APIResponse)
|
@router.get("/banks/institutions", response_model=APIResponse)
|
||||||
async def get_bank_institutions(
|
async def get_bank_institutions(
|
||||||
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)")
|
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)"),
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""Get available bank institutions for a country"""
|
"""Get available bank institutions for a country"""
|
||||||
try:
|
try:
|
||||||
institutions_data = await gocardless_service.get_institutions(country)
|
institutions_data = await gocardless_service.get_institutions(country)
|
||||||
|
|
||||||
institutions = [
|
institutions = [
|
||||||
BankInstitution(
|
BankInstitution(
|
||||||
id=inst["id"],
|
id=inst["id"],
|
||||||
@@ -31,20 +31,22 @@ async def get_bank_institutions(
|
|||||||
bic=inst.get("bic"),
|
bic=inst.get("bic"),
|
||||||
transaction_total_days=inst["transaction_total_days"],
|
transaction_total_days=inst["transaction_total_days"],
|
||||||
countries=inst["countries"],
|
countries=inst["countries"],
|
||||||
logo=inst.get("logo")
|
logo=inst.get("logo"),
|
||||||
)
|
)
|
||||||
for inst in institutions_data
|
for inst in institutions_data
|
||||||
]
|
]
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=institutions,
|
data=institutions,
|
||||||
message=f"Found {len(institutions)} institutions for {country}"
|
message=f"Found {len(institutions)} institutions for {country}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get institutions for {country}: {e}")
|
logger.error(f"Failed to get institutions for {country}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get institutions: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get institutions: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/banks/connect", response_model=APIResponse)
|
@router.post("/banks/connect", response_model=APIResponse)
|
||||||
@@ -52,28 +54,29 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
|
|||||||
"""Create a connection to a bank (requisition)"""
|
"""Create a connection to a bank (requisition)"""
|
||||||
try:
|
try:
|
||||||
requisition_data = await gocardless_service.create_requisition(
|
requisition_data = await gocardless_service.create_requisition(
|
||||||
request.institution_id,
|
request.institution_id, request.redirect_url
|
||||||
request.redirect_url
|
|
||||||
)
|
)
|
||||||
|
|
||||||
requisition = BankRequisition(
|
requisition = BankRequisition(
|
||||||
id=requisition_data["id"],
|
id=requisition_data["id"],
|
||||||
institution_id=requisition_data["institution_id"],
|
institution_id=requisition_data["institution_id"],
|
||||||
status=requisition_data["status"],
|
status=requisition_data["status"],
|
||||||
created=requisition_data["created"],
|
created=requisition_data["created"],
|
||||||
link=requisition_data["link"],
|
link=requisition_data["link"],
|
||||||
accounts=requisition_data.get("accounts", [])
|
accounts=requisition_data.get("accounts", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=requisition,
|
data=requisition,
|
||||||
message=f"Bank connection created. Please visit the link to authorize."
|
message=f"Bank connection created. Please visit the link to authorize.",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to bank {request.institution_id}: {e}")
|
logger.error(f"Failed to connect to bank {request.institution_id}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to connect to bank: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to connect to bank: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/banks/status", response_model=APIResponse)
|
@router.get("/banks/status", response_model=APIResponse)
|
||||||
@@ -81,31 +84,37 @@ async def get_bank_connections_status() -> APIResponse:
|
|||||||
"""Get status of all bank connections"""
|
"""Get status of all bank connections"""
|
||||||
try:
|
try:
|
||||||
requisitions_data = await gocardless_service.get_requisitions()
|
requisitions_data = await gocardless_service.get_requisitions()
|
||||||
|
|
||||||
connections = []
|
connections = []
|
||||||
for req in requisitions_data.get("results", []):
|
for req in requisitions_data.get("results", []):
|
||||||
status = req["status"]
|
status = req["status"]
|
||||||
status_display = REQUISITION_STATUS.get(status, "UNKNOWN")
|
status_display = REQUISITION_STATUS.get(status, "UNKNOWN")
|
||||||
|
|
||||||
connections.append(BankConnectionStatus(
|
connections.append(
|
||||||
bank_id=req["institution_id"],
|
BankConnectionStatus(
|
||||||
bank_name=req["institution_id"], # Could be enhanced with actual bank names
|
bank_id=req["institution_id"],
|
||||||
status=status,
|
bank_name=req[
|
||||||
status_display=status_display,
|
"institution_id"
|
||||||
created_at=req["created"],
|
], # Could be enhanced with actual bank names
|
||||||
requisition_id=req["id"],
|
status=status,
|
||||||
accounts_count=len(req.get("accounts", []))
|
status_display=status_display,
|
||||||
))
|
created_at=req["created"],
|
||||||
|
requisition_id=req["id"],
|
||||||
|
accounts_count=len(req.get("accounts", [])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=connections,
|
data=connections,
|
||||||
message=f"Found {len(connections)} bank connections"
|
message=f"Found {len(connections)} bank connections",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get bank connection status: {e}")
|
logger.error(f"Failed to get bank connection status: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get bank status: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get bank status: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/banks/connections/{requisition_id}", response_model=APIResponse)
|
@router.delete("/banks/connections/{requisition_id}", response_model=APIResponse)
|
||||||
@@ -116,12 +125,14 @@ async def delete_bank_connection(requisition_id: str) -> APIResponse:
|
|||||||
# For now, return success
|
# For now, return success
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
message=f"Bank connection {requisition_id} deleted successfully"
|
message=f"Bank connection {requisition_id} deleted successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete bank connection {requisition_id}: {e}")
|
logger.error(f"Failed to delete bank connection {requisition_id}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to delete connection: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to delete connection: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/banks/countries", response_model=APIResponse)
|
@router.get("/banks/countries", response_model=APIResponse)
|
||||||
@@ -160,9 +171,9 @@ async def get_supported_countries() -> APIResponse:
|
|||||||
{"code": "SE", "name": "Sweden"},
|
{"code": "SE", "name": "Sweden"},
|
||||||
{"code": "GB", "name": "United Kingdom"},
|
{"code": "GB", "name": "United Kingdom"},
|
||||||
]
|
]
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=countries,
|
data=countries,
|
||||||
message="Supported countries retrieved successfully"
|
message="Supported countries retrieved successfully",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from loguru import logger
|
|||||||
|
|
||||||
from leggend.api.models.common import APIResponse
|
from leggend.api.models.common import APIResponse
|
||||||
from leggend.api.models.notifications import (
|
from leggend.api.models.notifications import (
|
||||||
NotificationSettings,
|
NotificationSettings,
|
||||||
NotificationTest,
|
NotificationTest,
|
||||||
DiscordConfig,
|
DiscordConfig,
|
||||||
TelegramConfig,
|
TelegramConfig,
|
||||||
NotificationFilters
|
NotificationFilters,
|
||||||
)
|
)
|
||||||
from leggend.services.notification_service import NotificationService
|
from leggend.services.notification_service import NotificationService
|
||||||
from leggend.config import config
|
from leggend.config import config
|
||||||
@@ -23,38 +23,44 @@ async def get_notification_settings() -> APIResponse:
|
|||||||
try:
|
try:
|
||||||
notifications_config = config.notifications_config
|
notifications_config = config.notifications_config
|
||||||
filters_config = config.filters_config
|
filters_config = config.filters_config
|
||||||
|
|
||||||
# Build response safely without exposing secrets
|
# Build response safely without exposing secrets
|
||||||
discord_config = notifications_config.get("discord", {})
|
discord_config = notifications_config.get("discord", {})
|
||||||
telegram_config = notifications_config.get("telegram", {})
|
telegram_config = notifications_config.get("telegram", {})
|
||||||
|
|
||||||
settings = NotificationSettings(
|
settings = NotificationSettings(
|
||||||
discord=DiscordConfig(
|
discord=DiscordConfig(
|
||||||
webhook="***" if discord_config.get("webhook") else "",
|
webhook="***" if discord_config.get("webhook") else "",
|
||||||
enabled=discord_config.get("enabled", True)
|
enabled=discord_config.get("enabled", True),
|
||||||
) if discord_config.get("webhook") else None,
|
)
|
||||||
|
if discord_config.get("webhook")
|
||||||
|
else None,
|
||||||
telegram=TelegramConfig(
|
telegram=TelegramConfig(
|
||||||
token="***" if telegram_config.get("token") else "",
|
token="***" if telegram_config.get("token") else "",
|
||||||
chat_id=telegram_config.get("chat_id", 0),
|
chat_id=telegram_config.get("chat_id", 0),
|
||||||
enabled=telegram_config.get("enabled", True)
|
enabled=telegram_config.get("enabled", True),
|
||||||
) if telegram_config.get("token") else None,
|
)
|
||||||
|
if telegram_config.get("token")
|
||||||
|
else None,
|
||||||
filters=NotificationFilters(
|
filters=NotificationFilters(
|
||||||
case_insensitive=filters_config.get("case-insensitive", {}),
|
case_insensitive=filters_config.get("case-insensitive", {}),
|
||||||
case_sensitive=filters_config.get("case-sensitive"),
|
case_sensitive=filters_config.get("case-sensitive"),
|
||||||
amount_threshold=filters_config.get("amount_threshold"),
|
amount_threshold=filters_config.get("amount_threshold"),
|
||||||
keywords=filters_config.get("keywords", [])
|
keywords=filters_config.get("keywords", []),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=settings,
|
data=settings,
|
||||||
message="Notification settings retrieved successfully"
|
message="Notification settings retrieved successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get notification settings: {e}")
|
logger.error(f"Failed to get notification settings: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get notification settings: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get notification settings: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/notifications/settings", response_model=APIResponse)
|
@router.put("/notifications/settings", response_model=APIResponse)
|
||||||
@@ -63,20 +69,20 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
|
|||||||
try:
|
try:
|
||||||
# Update notifications config
|
# Update notifications config
|
||||||
notifications_config = {}
|
notifications_config = {}
|
||||||
|
|
||||||
if settings.discord:
|
if settings.discord:
|
||||||
notifications_config["discord"] = {
|
notifications_config["discord"] = {
|
||||||
"webhook": settings.discord.webhook,
|
"webhook": settings.discord.webhook,
|
||||||
"enabled": settings.discord.enabled
|
"enabled": settings.discord.enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.telegram:
|
if settings.telegram:
|
||||||
notifications_config["telegram"] = {
|
notifications_config["telegram"] = {
|
||||||
"token": settings.telegram.token,
|
"token": settings.telegram.token,
|
||||||
"chat_id": settings.telegram.chat_id,
|
"chat_id": settings.telegram.chat_id,
|
||||||
"enabled": settings.telegram.enabled
|
"enabled": settings.telegram.enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update filters config
|
# Update filters config
|
||||||
filters_config = {}
|
filters_config = {}
|
||||||
if settings.filters.case_insensitive:
|
if settings.filters.case_insensitive:
|
||||||
@@ -87,22 +93,24 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes
|
|||||||
filters_config["amount_threshold"] = settings.filters.amount_threshold
|
filters_config["amount_threshold"] = settings.filters.amount_threshold
|
||||||
if settings.filters.keywords:
|
if settings.filters.keywords:
|
||||||
filters_config["keywords"] = settings.filters.keywords
|
filters_config["keywords"] = settings.filters.keywords
|
||||||
|
|
||||||
# Save to config
|
# Save to config
|
||||||
if notifications_config:
|
if notifications_config:
|
||||||
config.update_section("notifications", notifications_config)
|
config.update_section("notifications", notifications_config)
|
||||||
if filters_config:
|
if filters_config:
|
||||||
config.update_section("filters", filters_config)
|
config.update_section("filters", filters_config)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data={"updated": True},
|
data={"updated": True},
|
||||||
message="Notification settings updated successfully"
|
message="Notification settings updated successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to update notification settings: {e}")
|
logger.error(f"Failed to update notification settings: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to update notification settings: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to update notification settings: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/notifications/test", response_model=APIResponse)
|
@router.post("/notifications/test", response_model=APIResponse)
|
||||||
@@ -110,25 +118,26 @@ async def test_notification(test_request: NotificationTest) -> APIResponse:
|
|||||||
"""Send a test notification"""
|
"""Send a test notification"""
|
||||||
try:
|
try:
|
||||||
success = await notification_service.send_test_notification(
|
success = await notification_service.send_test_notification(
|
||||||
test_request.service,
|
test_request.service, test_request.message
|
||||||
test_request.message
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data={"sent": True},
|
data={"sent": True},
|
||||||
message=f"Test notification sent to {test_request.service} successfully"
|
message=f"Test notification sent to {test_request.service} successfully",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=False,
|
success=False,
|
||||||
message=f"Failed to send test notification to {test_request.service}"
|
message=f"Failed to send test notification to {test_request.service}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send test notification: {e}")
|
logger.error(f"Failed to send test notification: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to send test notification: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to send test notification: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/notifications/services", response_model=APIResponse)
|
@router.get("/notifications/services", response_model=APIResponse)
|
||||||
@@ -136,37 +145,41 @@ async def get_notification_services() -> APIResponse:
|
|||||||
"""Get available notification services and their status"""
|
"""Get available notification services and their status"""
|
||||||
try:
|
try:
|
||||||
notifications_config = config.notifications_config
|
notifications_config = config.notifications_config
|
||||||
|
|
||||||
services = {
|
services = {
|
||||||
"discord": {
|
"discord": {
|
||||||
"name": "Discord",
|
"name": "Discord",
|
||||||
"enabled": bool(notifications_config.get("discord", {}).get("webhook")),
|
"enabled": bool(notifications_config.get("discord", {}).get("webhook")),
|
||||||
"configured": bool(notifications_config.get("discord", {}).get("webhook")),
|
"configured": bool(
|
||||||
"active": notifications_config.get("discord", {}).get("enabled", True)
|
notifications_config.get("discord", {}).get("webhook")
|
||||||
|
),
|
||||||
|
"active": notifications_config.get("discord", {}).get("enabled", True),
|
||||||
},
|
},
|
||||||
"telegram": {
|
"telegram": {
|
||||||
"name": "Telegram",
|
"name": "Telegram",
|
||||||
"enabled": bool(
|
"enabled": bool(
|
||||||
notifications_config.get("telegram", {}).get("token") and
|
notifications_config.get("telegram", {}).get("token")
|
||||||
notifications_config.get("telegram", {}).get("chat_id")
|
and notifications_config.get("telegram", {}).get("chat_id")
|
||||||
),
|
),
|
||||||
"configured": bool(
|
"configured": bool(
|
||||||
notifications_config.get("telegram", {}).get("token") and
|
notifications_config.get("telegram", {}).get("token")
|
||||||
notifications_config.get("telegram", {}).get("chat_id")
|
and notifications_config.get("telegram", {}).get("chat_id")
|
||||||
),
|
),
|
||||||
"active": notifications_config.get("telegram", {}).get("enabled", True)
|
"active": notifications_config.get("telegram", {}).get("enabled", True),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=services,
|
data=services,
|
||||||
message="Notification services status retrieved successfully"
|
message="Notification services status retrieved successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get notification services: {e}")
|
logger.error(f"Failed to get notification services: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get notification services: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get notification services: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/notifications/settings/{service}", response_model=APIResponse)
|
@router.delete("/notifications/settings/{service}", response_model=APIResponse)
|
||||||
@@ -174,19 +187,23 @@ async def delete_notification_service(service: str) -> APIResponse:
|
|||||||
"""Delete/disable a notification service"""
|
"""Delete/disable a notification service"""
|
||||||
try:
|
try:
|
||||||
if service not in ["discord", "telegram"]:
|
if service not in ["discord", "telegram"]:
|
||||||
raise HTTPException(status_code=400, detail="Service must be 'discord' or 'telegram'")
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Service must be 'discord' or 'telegram'"
|
||||||
|
)
|
||||||
|
|
||||||
notifications_config = config.notifications_config.copy()
|
notifications_config = config.notifications_config.copy()
|
||||||
if service in notifications_config:
|
if service in notifications_config:
|
||||||
del notifications_config[service]
|
del notifications_config[service]
|
||||||
config.update_section("notifications", notifications_config)
|
config.update_section("notifications", notifications_config)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data={"deleted": service},
|
data={"deleted": service},
|
||||||
message=f"{service.capitalize()} notification service deleted successfully"
|
message=f"{service.capitalize()} notification service deleted successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete notification service {service}: {e}")
|
logger.error(f"Failed to delete notification service {service}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to delete notification service: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to delete notification service: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,27 +17,26 @@ async def get_sync_status() -> APIResponse:
|
|||||||
"""Get current sync status"""
|
"""Get current sync status"""
|
||||||
try:
|
try:
|
||||||
status = await sync_service.get_sync_status()
|
status = await sync_service.get_sync_status()
|
||||||
|
|
||||||
# Add scheduler information
|
# Add scheduler information
|
||||||
next_sync_time = scheduler.get_next_sync_time()
|
next_sync_time = scheduler.get_next_sync_time()
|
||||||
if next_sync_time:
|
if next_sync_time:
|
||||||
status.next_sync = next_sync_time
|
status.next_sync = next_sync_time
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True, data=status, message="Sync status retrieved successfully"
|
||||||
data=status,
|
|
||||||
message="Sync status retrieved successfully"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get sync status: {e}")
|
logger.error(f"Failed to get sync status: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get sync status: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get sync status: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sync", response_model=APIResponse)
|
@router.post("/sync", response_model=APIResponse)
|
||||||
async def trigger_sync(
|
async def trigger_sync(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks, sync_request: Optional[SyncRequest] = None
|
||||||
sync_request: Optional[SyncRequest] = None
|
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""Trigger a manual sync operation"""
|
"""Trigger a manual sync operation"""
|
||||||
try:
|
try:
|
||||||
@@ -46,32 +45,37 @@ async def trigger_sync(
|
|||||||
if status.is_running and not (sync_request and sync_request.force):
|
if status.is_running and not (sync_request and sync_request.force):
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=False,
|
success=False,
|
||||||
message="Sync is already running. Use 'force: true' to override."
|
message="Sync is already running. Use 'force: true' to override.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine what to sync
|
# Determine what to sync
|
||||||
if sync_request and sync_request.account_ids:
|
if sync_request and sync_request.account_ids:
|
||||||
# Sync specific accounts in background
|
# Sync specific accounts in background
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
sync_service.sync_specific_accounts,
|
sync_service.sync_specific_accounts,
|
||||||
sync_request.account_ids,
|
sync_request.account_ids,
|
||||||
sync_request.force if sync_request else False
|
sync_request.force if sync_request else False,
|
||||||
|
)
|
||||||
|
message = (
|
||||||
|
f"Started sync for {len(sync_request.account_ids)} specific accounts"
|
||||||
)
|
)
|
||||||
message = f"Started sync for {len(sync_request.account_ids)} specific accounts"
|
|
||||||
else:
|
else:
|
||||||
# Sync all accounts in background
|
# Sync all accounts in background
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
sync_service.sync_all_accounts,
|
sync_service.sync_all_accounts,
|
||||||
sync_request.force if sync_request else False
|
sync_request.force if sync_request else False,
|
||||||
)
|
)
|
||||||
message = "Started sync for all accounts"
|
message = "Started sync for all accounts"
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data={"sync_started": True, "force": sync_request.force if sync_request else False},
|
data={
|
||||||
message=message
|
"sync_started": True,
|
||||||
|
"force": sync_request.force if sync_request else False,
|
||||||
|
},
|
||||||
|
message=message,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to trigger sync: {e}")
|
logger.error(f"Failed to trigger sync: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to trigger sync: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to trigger sync: {str(e)}")
|
||||||
@@ -83,20 +87,21 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
|
|||||||
try:
|
try:
|
||||||
if sync_request and sync_request.account_ids:
|
if sync_request and sync_request.account_ids:
|
||||||
result = await sync_service.sync_specific_accounts(
|
result = await sync_service.sync_specific_accounts(
|
||||||
sync_request.account_ids,
|
sync_request.account_ids, sync_request.force
|
||||||
sync_request.force
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await sync_service.sync_all_accounts(
|
result = await sync_service.sync_all_accounts(
|
||||||
sync_request.force if sync_request else False
|
sync_request.force if sync_request else False
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=result.success,
|
success=result.success,
|
||||||
data=result,
|
data=result,
|
||||||
message="Sync completed" if result.success else f"Sync failed with {len(result.errors)} errors"
|
message="Sync completed"
|
||||||
|
if result.success
|
||||||
|
else f"Sync failed with {len(result.errors)} errors",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to run sync: {e}")
|
logger.error(f"Failed to run sync: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to run sync: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to run sync: {str(e)}")
|
||||||
@@ -108,22 +113,28 @@ async def get_scheduler_config() -> APIResponse:
|
|||||||
try:
|
try:
|
||||||
scheduler_config = config.scheduler_config
|
scheduler_config = config.scheduler_config
|
||||||
next_sync_time = scheduler.get_next_sync_time()
|
next_sync_time = scheduler.get_next_sync_time()
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
**scheduler_config,
|
**scheduler_config,
|
||||||
"next_scheduled_sync": next_sync_time.isoformat() if next_sync_time else None,
|
"next_scheduled_sync": next_sync_time.isoformat()
|
||||||
"is_running": scheduler.scheduler.running if hasattr(scheduler, 'scheduler') else False
|
if next_sync_time
|
||||||
|
else None,
|
||||||
|
"is_running": scheduler.scheduler.running
|
||||||
|
if hasattr(scheduler, "scheduler")
|
||||||
|
else False,
|
||||||
}
|
}
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=response_data,
|
data=response_data,
|
||||||
message="Scheduler configuration retrieved successfully"
|
message="Scheduler configuration retrieved successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get scheduler config: {e}")
|
logger.error(f"Failed to get scheduler config: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get scheduler config: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get scheduler config: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/sync/scheduler", response_model=APIResponse)
|
@router.put("/sync/scheduler", response_model=APIResponse)
|
||||||
@@ -135,26 +146,32 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo
|
|||||||
try:
|
try:
|
||||||
cron_parts = scheduler_config.cron.split()
|
cron_parts = scheduler_config.cron.split()
|
||||||
if len(cron_parts) != 5:
|
if len(cron_parts) != 5:
|
||||||
raise ValueError("Cron expression must have 5 parts: minute hour day month day_of_week")
|
raise ValueError(
|
||||||
|
"Cron expression must have 5 parts: minute hour day month day_of_week"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid cron expression: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Invalid cron expression: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Update configuration
|
# Update configuration
|
||||||
schedule_data = scheduler_config.dict(exclude_none=True)
|
schedule_data = scheduler_config.dict(exclude_none=True)
|
||||||
config.update_section("scheduler", {"sync": schedule_data})
|
config.update_section("scheduler", {"sync": schedule_data})
|
||||||
|
|
||||||
# Reschedule the job
|
# Reschedule the job
|
||||||
scheduler.reschedule_sync(schedule_data)
|
scheduler.reschedule_sync(schedule_data)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=schedule_data,
|
data=schedule_data,
|
||||||
message="Scheduler configuration updated successfully"
|
message="Scheduler configuration updated successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to update scheduler config: {e}")
|
logger.error(f"Failed to update scheduler config: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to update scheduler config: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to update scheduler config: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sync/scheduler/start", response_model=APIResponse)
|
@router.post("/sync/scheduler/start", response_model=APIResponse)
|
||||||
@@ -163,37 +180,29 @@ async def start_scheduler() -> APIResponse:
|
|||||||
try:
|
try:
|
||||||
if not scheduler.scheduler.running:
|
if not scheduler.scheduler.running:
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
return APIResponse(
|
return APIResponse(success=True, message="Scheduler started successfully")
|
||||||
success=True,
|
|
||||||
message="Scheduler started successfully"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return APIResponse(
|
return APIResponse(success=True, message="Scheduler is already running")
|
||||||
success=True,
|
|
||||||
message="Scheduler is already running"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start scheduler: {e}")
|
logger.error(f"Failed to start scheduler: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to start scheduler: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to start scheduler: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sync/scheduler/stop", response_model=APIResponse)
|
@router.post("/sync/scheduler/stop", response_model=APIResponse)
|
||||||
async def stop_scheduler() -> APIResponse:
|
async def stop_scheduler() -> APIResponse:
|
||||||
"""Stop the background scheduler"""
|
"""Stop the background scheduler"""
|
||||||
try:
|
try:
|
||||||
if scheduler.scheduler.running:
|
if scheduler.scheduler.running:
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
return APIResponse(
|
return APIResponse(success=True, message="Scheduler stopped successfully")
|
||||||
success=True,
|
|
||||||
message="Scheduler stopped successfully"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return APIResponse(
|
return APIResponse(success=True, message="Scheduler is already stopped")
|
||||||
success=True,
|
|
||||||
message="Scheduler is already stopped"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to stop scheduler: {e}")
|
logger.error(f"Failed to stop scheduler: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to stop scheduler: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to stop scheduler: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,95 +17,111 @@ database_service = DatabaseService()
|
|||||||
async def get_all_transactions(
|
async def get_all_transactions(
|
||||||
limit: Optional[int] = Query(default=100, le=500),
|
limit: Optional[int] = Query(default=100, le=500),
|
||||||
offset: Optional[int] = Query(default=0, ge=0),
|
offset: Optional[int] = Query(default=0, ge=0),
|
||||||
summary_only: bool = Query(default=True, description="Return transaction summaries only"),
|
summary_only: bool = Query(
|
||||||
date_from: Optional[str] = Query(default=None, description="Filter from date (YYYY-MM-DD)"),
|
default=True, description="Return transaction summaries only"
|
||||||
date_to: Optional[str] = Query(default=None, description="Filter to date (YYYY-MM-DD)"),
|
),
|
||||||
min_amount: Optional[float] = Query(default=None, description="Minimum transaction amount"),
|
date_from: Optional[str] = Query(
|
||||||
max_amount: Optional[float] = Query(default=None, description="Maximum transaction amount"),
|
default=None, description="Filter from date (YYYY-MM-DD)"
|
||||||
search: Optional[str] = Query(default=None, description="Search in transaction descriptions"),
|
),
|
||||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID")
|
date_to: Optional[str] = Query(
|
||||||
|
default=None, description="Filter to date (YYYY-MM-DD)"
|
||||||
|
),
|
||||||
|
min_amount: Optional[float] = Query(
|
||||||
|
default=None, description="Minimum transaction amount"
|
||||||
|
),
|
||||||
|
max_amount: Optional[float] = Query(
|
||||||
|
default=None, description="Maximum transaction amount"
|
||||||
|
),
|
||||||
|
search: Optional[str] = Query(
|
||||||
|
default=None, description="Search in transaction descriptions"
|
||||||
|
),
|
||||||
|
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""Get all transactions across all accounts with filtering options"""
|
"""Get all transactions across all accounts with filtering options"""
|
||||||
try:
|
try:
|
||||||
# Get all requisitions and accounts
|
# Get all requisitions and accounts
|
||||||
requisitions_data = await gocardless_service.get_requisitions()
|
requisitions_data = await gocardless_service.get_requisitions()
|
||||||
all_accounts = set()
|
all_accounts = set()
|
||||||
|
|
||||||
for req in requisitions_data.get("results", []):
|
for req in requisitions_data.get("results", []):
|
||||||
all_accounts.update(req.get("accounts", []))
|
all_accounts.update(req.get("accounts", []))
|
||||||
|
|
||||||
# Filter by specific account if requested
|
# Filter by specific account if requested
|
||||||
if account_id:
|
if account_id:
|
||||||
if account_id not in all_accounts:
|
if account_id not in all_accounts:
|
||||||
raise HTTPException(status_code=404, detail="Account not found")
|
raise HTTPException(status_code=404, detail="Account not found")
|
||||||
all_accounts = {account_id}
|
all_accounts = {account_id}
|
||||||
|
|
||||||
all_transactions = []
|
all_transactions = []
|
||||||
|
|
||||||
# Collect transactions from all accounts
|
# Collect transactions from all accounts
|
||||||
for acc_id in all_accounts:
|
for acc_id in all_accounts:
|
||||||
try:
|
try:
|
||||||
account_details = await gocardless_service.get_account_details(acc_id)
|
account_details = await gocardless_service.get_account_details(acc_id)
|
||||||
transactions_data = await gocardless_service.get_account_transactions(acc_id)
|
transactions_data = await gocardless_service.get_account_transactions(
|
||||||
|
acc_id
|
||||||
|
)
|
||||||
|
|
||||||
processed_transactions = database_service.process_transactions(
|
processed_transactions = database_service.process_transactions(
|
||||||
acc_id, account_details, transactions_data
|
acc_id, account_details, transactions_data
|
||||||
)
|
)
|
||||||
all_transactions.extend(processed_transactions)
|
all_transactions.extend(processed_transactions)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get transactions for account {acc_id}: {e}")
|
logger.error(f"Failed to get transactions for account {acc_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
filtered_transactions = all_transactions
|
filtered_transactions = all_transactions
|
||||||
|
|
||||||
# Date range filter
|
# Date range filter
|
||||||
if date_from:
|
if date_from:
|
||||||
from_date = datetime.fromisoformat(date_from)
|
from_date = datetime.fromisoformat(date_from)
|
||||||
filtered_transactions = [
|
filtered_transactions = [
|
||||||
txn for txn in filtered_transactions
|
txn
|
||||||
|
for txn in filtered_transactions
|
||||||
if txn["transactionDate"] >= from_date
|
if txn["transactionDate"] >= from_date
|
||||||
]
|
]
|
||||||
|
|
||||||
if date_to:
|
if date_to:
|
||||||
to_date = datetime.fromisoformat(date_to)
|
to_date = datetime.fromisoformat(date_to)
|
||||||
filtered_transactions = [
|
filtered_transactions = [
|
||||||
txn for txn in filtered_transactions
|
txn
|
||||||
|
for txn in filtered_transactions
|
||||||
if txn["transactionDate"] <= to_date
|
if txn["transactionDate"] <= to_date
|
||||||
]
|
]
|
||||||
|
|
||||||
# Amount filters
|
# Amount filters
|
||||||
if min_amount is not None:
|
if min_amount is not None:
|
||||||
filtered_transactions = [
|
filtered_transactions = [
|
||||||
txn for txn in filtered_transactions
|
txn
|
||||||
|
for txn in filtered_transactions
|
||||||
if txn["transactionValue"] >= min_amount
|
if txn["transactionValue"] >= min_amount
|
||||||
]
|
]
|
||||||
|
|
||||||
if max_amount is not None:
|
if max_amount is not None:
|
||||||
filtered_transactions = [
|
filtered_transactions = [
|
||||||
txn for txn in filtered_transactions
|
txn
|
||||||
|
for txn in filtered_transactions
|
||||||
if txn["transactionValue"] <= max_amount
|
if txn["transactionValue"] <= max_amount
|
||||||
]
|
]
|
||||||
|
|
||||||
# Search filter
|
# Search filter
|
||||||
if search:
|
if search:
|
||||||
search_lower = search.lower()
|
search_lower = search.lower()
|
||||||
filtered_transactions = [
|
filtered_transactions = [
|
||||||
txn for txn in filtered_transactions
|
txn
|
||||||
|
for txn in filtered_transactions
|
||||||
if search_lower in txn["description"].lower()
|
if search_lower in txn["description"].lower()
|
||||||
]
|
]
|
||||||
|
|
||||||
# Sort by date (newest first)
|
# Sort by date (newest first)
|
||||||
filtered_transactions.sort(
|
filtered_transactions.sort(key=lambda x: x["transactionDate"], reverse=True)
|
||||||
key=lambda x: x["transactionDate"],
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
total_transactions = len(filtered_transactions)
|
total_transactions = len(filtered_transactions)
|
||||||
paginated_transactions = filtered_transactions[offset:offset + limit]
|
paginated_transactions = filtered_transactions[offset : offset + limit]
|
||||||
|
|
||||||
if summary_only:
|
if summary_only:
|
||||||
# Return simplified transaction summaries
|
# Return simplified transaction summaries
|
||||||
data = [
|
data = [
|
||||||
@@ -116,7 +132,7 @@ async def get_all_transactions(
|
|||||||
amount=txn["transactionValue"],
|
amount=txn["transactionValue"],
|
||||||
currency=txn["transactionCurrency"],
|
currency=txn["transactionCurrency"],
|
||||||
status=txn["transactionStatus"],
|
status=txn["transactionStatus"],
|
||||||
account_id=txn["accountId"]
|
account_id=txn["accountId"],
|
||||||
)
|
)
|
||||||
for txn in paginated_transactions
|
for txn in paginated_transactions
|
||||||
]
|
]
|
||||||
@@ -133,86 +149,99 @@ async def get_all_transactions(
|
|||||||
transaction_value=txn["transactionValue"],
|
transaction_value=txn["transactionValue"],
|
||||||
transaction_currency=txn["transactionCurrency"],
|
transaction_currency=txn["transactionCurrency"],
|
||||||
transaction_status=txn["transactionStatus"],
|
transaction_status=txn["transactionStatus"],
|
||||||
raw_transaction=txn["rawTransaction"]
|
raw_transaction=txn["rawTransaction"],
|
||||||
)
|
)
|
||||||
for txn in paginated_transactions
|
for txn in paginated_transactions
|
||||||
]
|
]
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=data,
|
data=data,
|
||||||
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})"
|
message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get transactions: {e}")
|
logger.error(f"Failed to get transactions: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get transactions: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get transactions: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transactions/stats", response_model=APIResponse)
|
@router.get("/transactions/stats", response_model=APIResponse)
|
||||||
async def get_transaction_stats(
|
async def get_transaction_stats(
|
||||||
days: int = Query(default=30, description="Number of days to include in stats"),
|
days: int = Query(default=30, description="Number of days to include in stats"),
|
||||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID")
|
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""Get transaction statistics for the last N days"""
|
"""Get transaction statistics for the last N days"""
|
||||||
try:
|
try:
|
||||||
# Date range for stats
|
# Date range for stats
|
||||||
end_date = datetime.now()
|
end_date = datetime.now()
|
||||||
start_date = end_date - timedelta(days=days)
|
start_date = end_date - timedelta(days=days)
|
||||||
|
|
||||||
# Get all transactions (reuse the existing endpoint logic)
|
# Get all transactions (reuse the existing endpoint logic)
|
||||||
# This is a simplified implementation - in practice you might want to optimize this
|
# This is a simplified implementation - in practice you might want to optimize this
|
||||||
requisitions_data = await gocardless_service.get_requisitions()
|
requisitions_data = await gocardless_service.get_requisitions()
|
||||||
all_accounts = set()
|
all_accounts = set()
|
||||||
|
|
||||||
for req in requisitions_data.get("results", []):
|
for req in requisitions_data.get("results", []):
|
||||||
all_accounts.update(req.get("accounts", []))
|
all_accounts.update(req.get("accounts", []))
|
||||||
|
|
||||||
if account_id:
|
if account_id:
|
||||||
if account_id not in all_accounts:
|
if account_id not in all_accounts:
|
||||||
raise HTTPException(status_code=404, detail="Account not found")
|
raise HTTPException(status_code=404, detail="Account not found")
|
||||||
all_accounts = {account_id}
|
all_accounts = {account_id}
|
||||||
|
|
||||||
all_transactions = []
|
all_transactions = []
|
||||||
|
|
||||||
for acc_id in all_accounts:
|
for acc_id in all_accounts:
|
||||||
try:
|
try:
|
||||||
account_details = await gocardless_service.get_account_details(acc_id)
|
account_details = await gocardless_service.get_account_details(acc_id)
|
||||||
transactions_data = await gocardless_service.get_account_transactions(acc_id)
|
transactions_data = await gocardless_service.get_account_transactions(
|
||||||
|
acc_id
|
||||||
|
)
|
||||||
|
|
||||||
processed_transactions = database_service.process_transactions(
|
processed_transactions = database_service.process_transactions(
|
||||||
acc_id, account_details, transactions_data
|
acc_id, account_details, transactions_data
|
||||||
)
|
)
|
||||||
all_transactions.extend(processed_transactions)
|
all_transactions.extend(processed_transactions)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get transactions for account {acc_id}: {e}")
|
logger.error(f"Failed to get transactions for account {acc_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Filter transactions by date range
|
# Filter transactions by date range
|
||||||
recent_transactions = [
|
recent_transactions = [
|
||||||
txn for txn in all_transactions
|
txn
|
||||||
|
for txn in all_transactions
|
||||||
if start_date <= txn["transactionDate"] <= end_date
|
if start_date <= txn["transactionDate"] <= end_date
|
||||||
]
|
]
|
||||||
|
|
||||||
# Calculate stats
|
# Calculate stats
|
||||||
total_transactions = len(recent_transactions)
|
total_transactions = len(recent_transactions)
|
||||||
total_income = sum(
|
total_income = sum(
|
||||||
txn["transactionValue"]
|
txn["transactionValue"]
|
||||||
for txn in recent_transactions
|
for txn in recent_transactions
|
||||||
if txn["transactionValue"] > 0
|
if txn["transactionValue"] > 0
|
||||||
)
|
)
|
||||||
total_expenses = sum(
|
total_expenses = sum(
|
||||||
abs(txn["transactionValue"])
|
abs(txn["transactionValue"])
|
||||||
for txn in recent_transactions
|
for txn in recent_transactions
|
||||||
if txn["transactionValue"] < 0
|
if txn["transactionValue"] < 0
|
||||||
)
|
)
|
||||||
net_change = total_income - total_expenses
|
net_change = total_income - total_expenses
|
||||||
|
|
||||||
# Count by status
|
# Count by status
|
||||||
booked_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "booked"])
|
booked_count = len(
|
||||||
pending_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "pending"])
|
[txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]
|
||||||
|
)
|
||||||
|
pending_count = len(
|
||||||
|
[
|
||||||
|
txn
|
||||||
|
for txn in recent_transactions
|
||||||
|
if txn["transactionStatus"] == "pending"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
"period_days": days,
|
"period_days": days,
|
||||||
"total_transactions": total_transactions,
|
"total_transactions": total_transactions,
|
||||||
@@ -222,17 +251,23 @@ async def get_transaction_stats(
|
|||||||
"total_expenses": round(total_expenses, 2),
|
"total_expenses": round(total_expenses, 2),
|
||||||
"net_change": round(net_change, 2),
|
"net_change": round(net_change, 2),
|
||||||
"average_transaction": round(
|
"average_transaction": round(
|
||||||
sum(txn["transactionValue"] for txn in recent_transactions) / total_transactions, 2
|
sum(txn["transactionValue"] for txn in recent_transactions)
|
||||||
) if total_transactions > 0 else 0,
|
/ total_transactions,
|
||||||
"accounts_included": len(all_accounts)
|
2,
|
||||||
|
)
|
||||||
|
if total_transactions > 0
|
||||||
|
else 0,
|
||||||
|
"accounts_included": len(all_accounts),
|
||||||
}
|
}
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
success=True,
|
success=True,
|
||||||
data=stats,
|
data=stats,
|
||||||
message=f"Transaction statistics for last {days} days"
|
message=f"Transaction statistics for last {days} days",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get transaction stats: {e}")
|
logger.error(f"Failed to get transaction stats: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get transaction stats: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get transaction stats: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,47 +4,30 @@ from loguru import logger
|
|||||||
|
|
||||||
from leggend.config import config
|
from leggend.config import config
|
||||||
from leggend.services.sync_service import SyncService
|
from leggend.services.sync_service import SyncService
|
||||||
|
from leggend.services.notification_service import NotificationService
|
||||||
|
|
||||||
|
|
||||||
class BackgroundScheduler:
|
class BackgroundScheduler:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scheduler = AsyncIOScheduler()
|
self.scheduler = AsyncIOScheduler()
|
||||||
self.sync_service = SyncService()
|
self.sync_service = SyncService()
|
||||||
|
self.notification_service = NotificationService()
|
||||||
|
self.max_retries = 3
|
||||||
|
self.retry_delay = 300 # 5 minutes
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the scheduler and configure sync jobs based on configuration"""
|
"""Start the scheduler and configure sync jobs based on configuration"""
|
||||||
schedule_config = config.scheduler_config.get("sync", {})
|
schedule_config = config.scheduler_config.get("sync", {})
|
||||||
|
|
||||||
if not schedule_config.get("enabled", True):
|
if not schedule_config.get("enabled", True):
|
||||||
logger.info("Sync scheduling is disabled in configuration")
|
logger.info("Sync scheduling is disabled in configuration")
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Use custom cron expression if provided, otherwise use hour/minute
|
# Parse schedule configuration
|
||||||
if schedule_config.get("cron"):
|
trigger = self._parse_cron_config(schedule_config)
|
||||||
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
|
if not trigger:
|
||||||
try:
|
return
|
||||||
cron_parts = schedule_config["cron"].split()
|
|
||||||
if len(cron_parts) == 5:
|
|
||||||
minute, hour, day, month, day_of_week = cron_parts
|
|
||||||
trigger = CronTrigger(
|
|
||||||
minute=minute,
|
|
||||||
hour=hour,
|
|
||||||
day=day if day != "*" else None,
|
|
||||||
month=month if month != "*" else None,
|
|
||||||
day_of_week=day_of_week if day_of_week != "*" else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error parsing cron expression: {e}")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# Use hour/minute configuration (default: 3:00 AM daily)
|
|
||||||
hour = schedule_config.get("hour", 3)
|
|
||||||
minute = schedule_config.get("minute", 0)
|
|
||||||
trigger = CronTrigger(hour=hour, minute=minute)
|
|
||||||
|
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._run_sync,
|
self._run_sync,
|
||||||
@@ -53,7 +36,7 @@ class BackgroundScheduler:
|
|||||||
name="Scheduled sync of all transactions",
|
name="Scheduled sync of all transactions",
|
||||||
max_instances=1,
|
max_instances=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
logger.info(f"Background scheduler started with sync job: {trigger}")
|
logger.info(f"Background scheduler started with sync job: {trigger}")
|
||||||
|
|
||||||
@@ -76,28 +59,9 @@ class BackgroundScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Configure new schedule
|
# Configure new schedule
|
||||||
if schedule_config.get("cron"):
|
trigger = self._parse_cron_config(schedule_config)
|
||||||
try:
|
if not trigger:
|
||||||
cron_parts = schedule_config["cron"].split()
|
return
|
||||||
if len(cron_parts) == 5:
|
|
||||||
minute, hour, day, month, day_of_week = cron_parts
|
|
||||||
trigger = CronTrigger(
|
|
||||||
minute=minute,
|
|
||||||
hour=hour,
|
|
||||||
day=day if day != "*" else None,
|
|
||||||
month=month if month != "*" else None,
|
|
||||||
day_of_week=day_of_week if day_of_week != "*" else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error parsing cron expression: {e}")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
hour = schedule_config.get("hour", 3)
|
|
||||||
minute = schedule_config.get("minute", 0)
|
|
||||||
trigger = CronTrigger(hour=hour, minute=minute)
|
|
||||||
|
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._run_sync,
|
self._run_sync,
|
||||||
@@ -108,13 +72,90 @@ class BackgroundScheduler:
|
|||||||
)
|
)
|
||||||
logger.info(f"Rescheduled sync job with: {trigger}")
|
logger.info(f"Rescheduled sync job with: {trigger}")
|
||||||
|
|
||||||
async def _run_sync(self):
|
def _parse_cron_config(self, schedule_config: dict) -> CronTrigger:
|
||||||
|
"""Parse cron configuration and return CronTrigger"""
|
||||||
|
if schedule_config.get("cron"):
|
||||||
|
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
|
||||||
|
try:
|
||||||
|
cron_parts = schedule_config["cron"].split()
|
||||||
|
if len(cron_parts) == 5:
|
||||||
|
minute, hour, day, month, day_of_week = cron_parts
|
||||||
|
return CronTrigger(
|
||||||
|
minute=minute,
|
||||||
|
hour=hour,
|
||||||
|
day=day if day != "*" else None,
|
||||||
|
month=month if month != "*" else None,
|
||||||
|
day_of_week=day_of_week if day_of_week != "*" else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing cron expression: {e}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# Use hour/minute configuration (default: 3:00 AM daily)
|
||||||
|
hour = schedule_config.get("hour", 3)
|
||||||
|
minute = schedule_config.get("minute", 0)
|
||||||
|
return CronTrigger(hour=hour, minute=minute)
|
||||||
|
|
||||||
|
async def _run_sync(self, retry_count: int = 0):
|
||||||
|
"""Run sync with enhanced error handling and retry logic"""
|
||||||
try:
|
try:
|
||||||
logger.info("Starting scheduled sync job")
|
logger.info("Starting scheduled sync job")
|
||||||
await self.sync_service.sync_all_accounts()
|
await self.sync_service.sync_all_accounts()
|
||||||
logger.info("Scheduled sync job completed successfully")
|
logger.info("Scheduled sync job completed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Scheduled sync job failed: {e}")
|
logger.error(
|
||||||
|
f"Scheduled sync job failed (attempt {retry_count + 1}/{self.max_retries}): {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send notification about the failure
|
||||||
|
try:
|
||||||
|
await self.notification_service.send_expiry_notification(
|
||||||
|
{
|
||||||
|
"type": "sync_failure",
|
||||||
|
"error": str(e),
|
||||||
|
"retry_count": retry_count + 1,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as notification_error:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to send failure notification: {notification_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Implement retry logic for transient failures
|
||||||
|
if retry_count < self.max_retries - 1:
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
logger.info(f"Retrying sync job in {self.retry_delay} seconds...")
|
||||||
|
# Schedule a retry
|
||||||
|
retry_time = datetime.datetime.now() + datetime.timedelta(
|
||||||
|
seconds=self.retry_delay
|
||||||
|
)
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self._run_sync,
|
||||||
|
"date",
|
||||||
|
args=[retry_count + 1],
|
||||||
|
id=f"sync_retry_{retry_count + 1}",
|
||||||
|
run_date=retry_time,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error("Maximum retries exceeded for sync job")
|
||||||
|
# Send final failure notification
|
||||||
|
try:
|
||||||
|
await self.notification_service.send_expiry_notification(
|
||||||
|
{
|
||||||
|
"type": "sync_final_failure",
|
||||||
|
"error": str(e),
|
||||||
|
"retry_count": retry_count + 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as notification_error:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to send final failure notification: {notification_error}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_next_sync_time(self):
|
def get_next_sync_time(self):
|
||||||
"""Get the next scheduled sync time"""
|
"""Get the next scheduled sync time"""
|
||||||
@@ -124,4 +165,4 @@ class BackgroundScheduler:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
scheduler = BackgroundScheduler()
|
scheduler = BackgroundScheduler()
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class Config:
|
|||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = os.environ.get(
|
config_path = os.environ.get(
|
||||||
"LEGGEN_CONFIG_FILE",
|
"LEGGEN_CONFIG_FILE",
|
||||||
str(Path.home() / ".config" / "leggen" / "config.toml")
|
str(Path.home() / ".config" / "leggen" / "config.toml"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._config_path = config_path
|
self._config_path = config_path
|
||||||
@@ -42,15 +42,17 @@ class Config:
|
|||||||
|
|
||||||
return self._config
|
return self._config
|
||||||
|
|
||||||
def save_config(self, config_data: Dict[str, Any] = None, config_path: str = None) -> None:
|
def save_config(
|
||||||
|
self, config_data: Dict[str, Any] = None, config_path: str = None
|
||||||
|
) -> None:
|
||||||
"""Save configuration to TOML file"""
|
"""Save configuration to TOML file"""
|
||||||
if config_data is None:
|
if config_data is None:
|
||||||
config_data = self._config
|
config_data = self._config
|
||||||
|
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = self._config_path or os.environ.get(
|
config_path = self._config_path or os.environ.get(
|
||||||
"LEGGEN_CONFIG_FILE",
|
"LEGGEN_CONFIG_FILE",
|
||||||
str(Path.home() / ".config" / "leggen" / "config.toml")
|
str(Path.home() / ".config" / "leggen" / "config.toml"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure directory exists
|
# Ensure directory exists
|
||||||
@@ -59,7 +61,7 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
with open(config_path, "wb") as f:
|
with open(config_path, "wb") as f:
|
||||||
tomli_w.dump(config_data, f)
|
tomli_w.dump(config_data, f)
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
self._config = config_data
|
self._config = config_data
|
||||||
self._config_path = config_path
|
self._config_path = config_path
|
||||||
@@ -72,10 +74,10 @@ class Config:
|
|||||||
"""Update a specific configuration value"""
|
"""Update a specific configuration value"""
|
||||||
if self._config is None:
|
if self._config is None:
|
||||||
self.load_config()
|
self.load_config()
|
||||||
|
|
||||||
if section not in self._config:
|
if section not in self._config:
|
||||||
self._config[section] = {}
|
self._config[section] = {}
|
||||||
|
|
||||||
self._config[section][key] = value
|
self._config[section][key] = value
|
||||||
self.save_config()
|
self.save_config()
|
||||||
|
|
||||||
@@ -83,7 +85,7 @@ class Config:
|
|||||||
"""Update an entire configuration section"""
|
"""Update an entire configuration section"""
|
||||||
if self._config is None:
|
if self._config is None:
|
||||||
self.load_config()
|
self.load_config()
|
||||||
|
|
||||||
self._config[section] = data
|
self._config[section] = data
|
||||||
self.save_config()
|
self.save_config()
|
||||||
|
|
||||||
@@ -117,10 +119,10 @@ class Config:
|
|||||||
"enabled": True,
|
"enabled": True,
|
||||||
"hour": 3,
|
"hour": 3,
|
||||||
"minute": 0,
|
"minute": 0,
|
||||||
"cron": None # Optional custom cron expression
|
"cron": None, # Optional custom cron expression
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return self.config.get("scheduler", default_schedule)
|
return self.config.get("scheduler", default_schedule)
|
||||||
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from importlib import metadata
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@@ -15,7 +16,7 @@ from leggend.config import config
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup
|
# Startup
|
||||||
logger.info("Starting leggend service...")
|
logger.info("Starting leggend service...")
|
||||||
|
|
||||||
# Load configuration
|
# Load configuration
|
||||||
try:
|
try:
|
||||||
config.load_config()
|
config.load_config()
|
||||||
@@ -27,26 +28,35 @@ async def lifespan(app: FastAPI):
|
|||||||
# Start background scheduler
|
# Start background scheduler
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
logger.info("Background scheduler started")
|
logger.info("Background scheduler started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown
|
||||||
logger.info("Shutting down leggend service...")
|
logger.info("Shutting down leggend service...")
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
|
# Get version dynamically from package metadata
|
||||||
|
try:
|
||||||
|
version = metadata.version("leggen")
|
||||||
|
except metadata.PackageNotFoundError:
|
||||||
|
version = "unknown"
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Leggend API",
|
title="Leggend API",
|
||||||
description="Open Banking API for Leggen",
|
description="Open Banking API for Leggen",
|
||||||
version="0.6.11",
|
version=version,
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add CORS middleware
|
# Add CORS middleware
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["http://localhost:3000", "http://localhost:5173"], # SvelteKit dev servers
|
allow_origins=[
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:5173",
|
||||||
|
], # SvelteKit dev servers
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
@@ -60,7 +70,12 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
return {"message": "Leggend API is running", "version": "0.6.11"}
|
# Get version dynamically
|
||||||
|
try:
|
||||||
|
version = metadata.version("leggen")
|
||||||
|
except metadata.PackageNotFoundError:
|
||||||
|
version = "unknown"
|
||||||
|
return {"message": "Leggend API is running", "version": version}
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
@@ -71,25 +86,19 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Start the Leggend API service")
|
parser = argparse.ArgumentParser(description="Start the Leggend API service")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--reload",
|
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||||
action="store_true",
|
|
||||||
help="Enable auto-reload for development"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
|
||||||
default="0.0.0.0",
|
|
||||||
help="Host to bind to (default: 0.0.0.0)"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port", type=int, default=8000, help="Port to bind to (default: 8000)"
|
||||||
type=int,
|
|
||||||
default=8000,
|
|
||||||
help="Port to bind to (default: 8000)"
|
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.reload:
|
if args.reload:
|
||||||
# Use string import for reload to work properly
|
# Use string import for reload to work properly
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
@@ -114,4 +123,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ class DatabaseService:
|
|||||||
self.db_config = config.database_config
|
self.db_config = config.database_config
|
||||||
self.sqlite_enabled = self.db_config.get("sqlite", True)
|
self.sqlite_enabled = self.db_config.get("sqlite", True)
|
||||||
|
|
||||||
async def persist_balance(self, account_id: str, balance_data: Dict[str, Any]) -> None:
|
async def persist_balance(
|
||||||
|
self, account_id: str, balance_data: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Persist account balance data"""
|
"""Persist account balance data"""
|
||||||
if not self.sqlite_enabled:
|
if not self.sqlite_enabled:
|
||||||
logger.warning("SQLite database disabled, skipping balance persistence")
|
logger.warning("SQLite database disabled, skipping balance persistence")
|
||||||
@@ -19,7 +21,9 @@ class DatabaseService:
|
|||||||
|
|
||||||
await self._persist_balance_sqlite(account_id, balance_data)
|
await self._persist_balance_sqlite(account_id, balance_data)
|
||||||
|
|
||||||
async def persist_transactions(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
async def persist_transactions(
|
||||||
|
self, account_id: str, transactions: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Persist transactions and return new transactions"""
|
"""Persist transactions and return new transactions"""
|
||||||
if not self.sqlite_enabled:
|
if not self.sqlite_enabled:
|
||||||
logger.warning("SQLite database disabled, skipping transaction persistence")
|
logger.warning("SQLite database disabled, skipping transaction persistence")
|
||||||
@@ -27,32 +31,48 @@ class DatabaseService:
|
|||||||
|
|
||||||
return await self._persist_transactions_sqlite(account_id, transactions)
|
return await self._persist_transactions_sqlite(account_id, transactions)
|
||||||
|
|
||||||
def process_transactions(self, account_id: str, account_info: Dict[str, Any], transaction_data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def process_transactions(
|
||||||
|
self,
|
||||||
|
account_id: str,
|
||||||
|
account_info: Dict[str, Any],
|
||||||
|
transaction_data: Dict[str, Any],
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Process raw transaction data into standardized format"""
|
"""Process raw transaction data into standardized format"""
|
||||||
transactions = []
|
transactions = []
|
||||||
|
|
||||||
# Process booked transactions
|
# Process booked transactions
|
||||||
for transaction in transaction_data.get("transactions", {}).get("booked", []):
|
for transaction in transaction_data.get("transactions", {}).get("booked", []):
|
||||||
processed = self._process_single_transaction(account_id, account_info, transaction, "booked")
|
processed = self._process_single_transaction(
|
||||||
|
account_id, account_info, transaction, "booked"
|
||||||
|
)
|
||||||
transactions.append(processed)
|
transactions.append(processed)
|
||||||
|
|
||||||
# Process pending transactions
|
# Process pending transactions
|
||||||
for transaction in transaction_data.get("transactions", {}).get("pending", []):
|
for transaction in transaction_data.get("transactions", {}).get("pending", []):
|
||||||
processed = self._process_single_transaction(account_id, account_info, transaction, "pending")
|
processed = self._process_single_transaction(
|
||||||
|
account_id, account_info, transaction, "pending"
|
||||||
|
)
|
||||||
transactions.append(processed)
|
transactions.append(processed)
|
||||||
|
|
||||||
return transactions
|
return transactions
|
||||||
|
|
||||||
def _process_single_transaction(self, account_id: str, account_info: Dict[str, Any], transaction: Dict[str, Any], status: str) -> Dict[str, Any]:
|
def _process_single_transaction(
|
||||||
|
self,
|
||||||
|
account_id: str,
|
||||||
|
account_info: Dict[str, Any],
|
||||||
|
transaction: Dict[str, Any],
|
||||||
|
status: str,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Process a single transaction into standardized format"""
|
"""Process a single transaction into standardized format"""
|
||||||
# Extract dates
|
# Extract dates
|
||||||
booked_date = transaction.get("bookingDateTime") or transaction.get("bookingDate")
|
booked_date = transaction.get("bookingDateTime") or transaction.get(
|
||||||
|
"bookingDate"
|
||||||
|
)
|
||||||
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
|
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
|
||||||
|
|
||||||
if booked_date and value_date:
|
if booked_date and value_date:
|
||||||
min_date = min(
|
min_date = min(
|
||||||
datetime.fromisoformat(booked_date),
|
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
|
||||||
datetime.fromisoformat(value_date)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
min_date = datetime.fromisoformat(booked_date or value_date)
|
min_date = datetime.fromisoformat(booked_date or value_date)
|
||||||
@@ -65,7 +85,7 @@ class DatabaseService:
|
|||||||
# Extract description
|
# Extract description
|
||||||
description = transaction.get(
|
description = transaction.get(
|
||||||
"remittanceInformationUnstructured",
|
"remittanceInformationUnstructured",
|
||||||
",".join(transaction.get("remittanceInformationUnstructuredArray", []))
|
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -81,13 +101,19 @@ class DatabaseService:
|
|||||||
"rawTransaction": transaction,
|
"rawTransaction": transaction,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _persist_balance_sqlite(self, account_id: str, balance_data: Dict[str, Any]) -> None:
|
async def _persist_balance_sqlite(
|
||||||
|
self, account_id: str, balance_data: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
"""Persist balance to SQLite - placeholder implementation"""
|
"""Persist balance to SQLite - placeholder implementation"""
|
||||||
# Would import and use leggen.database.sqlite
|
# Would import and use leggen.database.sqlite
|
||||||
logger.info(f"Persisting balance to SQLite for account {account_id}")
|
logger.info(f"Persisting balance to SQLite for account {account_id}")
|
||||||
|
|
||||||
async def _persist_transactions_sqlite(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
async def _persist_transactions_sqlite(
|
||||||
|
self, account_id: str, transactions: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Persist transactions to SQLite - placeholder implementation"""
|
"""Persist transactions to SQLite - placeholder implementation"""
|
||||||
# Would import and use leggen.database.sqlite
|
# Would import and use leggen.database.sqlite
|
||||||
logger.info(f"Persisting {len(transactions)} transactions to SQLite for account {account_id}")
|
logger.info(
|
||||||
return transactions # Return new transactions for notifications
|
f"Persisting {len(transactions)} transactions to SQLite for account {account_id}"
|
||||||
|
)
|
||||||
|
return transactions # Return new transactions for notifications
|
||||||
|
|||||||
@@ -12,37 +12,36 @@ from leggend.config import config
|
|||||||
class GoCardlessService:
|
class GoCardlessService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.config = config.gocardless_config
|
self.config = config.gocardless_config
|
||||||
self.base_url = self.config.get("url", "https://bankaccountdata.gocardless.com/api/v2")
|
self.base_url = self.config.get(
|
||||||
|
"url", "https://bankaccountdata.gocardless.com/api/v2"
|
||||||
|
)
|
||||||
self._token = None
|
self._token = None
|
||||||
|
|
||||||
async def _get_auth_headers(self) -> Dict[str, str]:
|
async def _get_auth_headers(self) -> Dict[str, str]:
|
||||||
"""Get authentication headers for GoCardless API"""
|
"""Get authentication headers for GoCardless API"""
|
||||||
token = await self._get_token()
|
token = await self._get_token()
|
||||||
return {
|
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||||
"Authorization": f"Bearer {token}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _get_token(self) -> str:
|
async def _get_token(self) -> str:
|
||||||
"""Get access token for GoCardless API"""
|
"""Get access token for GoCardless API"""
|
||||||
if self._token:
|
if self._token:
|
||||||
return self._token
|
return self._token
|
||||||
|
|
||||||
# Use ~/.config/leggen for consistency with main config
|
# Use ~/.config/leggen for consistency with main config
|
||||||
auth_file = Path.home() / ".config" / "leggen" / "auth.json"
|
auth_file = Path.home() / ".config" / "leggen" / "auth.json"
|
||||||
|
|
||||||
if auth_file.exists():
|
if auth_file.exists():
|
||||||
try:
|
try:
|
||||||
with open(auth_file, "r") as f:
|
with open(auth_file, "r") as f:
|
||||||
auth = json.load(f)
|
auth = json.load(f)
|
||||||
|
|
||||||
if auth.get("access"):
|
if auth.get("access"):
|
||||||
# Try to refresh the token
|
# Try to refresh the token
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/token/refresh/",
|
f"{self.base_url}/token/refresh/",
|
||||||
json={"refresh": auth["refresh"]}
|
json={"refresh": auth["refresh"]},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
auth.update(response.json())
|
auth.update(response.json())
|
||||||
@@ -84,7 +83,7 @@ class GoCardlessService:
|
|||||||
"""Save authentication data to file"""
|
"""Save authentication data to file"""
|
||||||
auth_file = Path.home() / ".config" / "leggen" / "auth.json"
|
auth_file = Path.home() / ".config" / "leggen" / "auth.json"
|
||||||
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(auth_file, "w") as f:
|
with open(auth_file, "w") as f:
|
||||||
json.dump(auth_data, f)
|
json.dump(auth_data, f)
|
||||||
|
|
||||||
@@ -95,22 +94,21 @@ class GoCardlessService:
|
|||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{self.base_url}/institutions/",
|
f"{self.base_url}/institutions/",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params={"country": country}
|
params={"country": country},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
async def create_requisition(self, institution_id: str, redirect_url: str) -> Dict[str, Any]:
|
async def create_requisition(
|
||||||
|
self, institution_id: str, redirect_url: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Create a bank connection requisition"""
|
"""Create a bank connection requisition"""
|
||||||
headers = await self._get_auth_headers()
|
headers = await self._get_auth_headers()
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/requisitions/",
|
f"{self.base_url}/requisitions/",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json={
|
json={"institution_id": institution_id, "redirect": redirect_url},
|
||||||
"institution_id": institution_id,
|
|
||||||
"redirect": redirect_url
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -120,8 +118,7 @@ class GoCardlessService:
|
|||||||
headers = await self._get_auth_headers()
|
headers = await self._get_auth_headers()
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{self.base_url}/requisitions/",
|
f"{self.base_url}/requisitions/", headers=headers
|
||||||
headers=headers
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -131,8 +128,7 @@ class GoCardlessService:
|
|||||||
headers = await self._get_auth_headers()
|
headers = await self._get_auth_headers()
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{self.base_url}/accounts/{account_id}/",
|
f"{self.base_url}/accounts/{account_id}/", headers=headers
|
||||||
headers=headers
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -142,8 +138,7 @@ class GoCardlessService:
|
|||||||
headers = await self._get_auth_headers()
|
headers = await self._get_auth_headers()
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{self.base_url}/accounts/{account_id}/balances/",
|
f"{self.base_url}/accounts/{account_id}/balances/", headers=headers
|
||||||
headers=headers
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -153,8 +148,7 @@ class GoCardlessService:
|
|||||||
headers = await self._get_auth_headers()
|
headers = await self._get_auth_headers()
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{self.base_url}/accounts/{account_id}/transactions/",
|
f"{self.base_url}/accounts/{account_id}/transactions/", headers=headers
|
||||||
headers=headers
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ class NotificationService:
|
|||||||
self.notifications_config = config.notifications_config
|
self.notifications_config = config.notifications_config
|
||||||
self.filters_config = config.filters_config
|
self.filters_config = config.filters_config
|
||||||
|
|
||||||
async def send_transaction_notifications(self, transactions: List[Dict[str, Any]]) -> None:
|
async def send_transaction_notifications(
|
||||||
|
self, transactions: List[Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
"""Send notifications for new transactions that match filters"""
|
"""Send notifications for new transactions that match filters"""
|
||||||
if not self.filters_config:
|
if not self.filters_config:
|
||||||
logger.info("No notification filters configured, skipping notifications")
|
logger.info("No notification filters configured, skipping notifications")
|
||||||
@@ -18,7 +20,7 @@ class NotificationService:
|
|||||||
|
|
||||||
# Filter transactions that match notification criteria
|
# Filter transactions that match notification criteria
|
||||||
matching_transactions = self._filter_transactions(transactions)
|
matching_transactions = self._filter_transactions(transactions)
|
||||||
|
|
||||||
if not matching_transactions:
|
if not matching_transactions:
|
||||||
logger.info("No transactions matched notification filters")
|
logger.info("No transactions matched notification filters")
|
||||||
return
|
return
|
||||||
@@ -26,7 +28,7 @@ class NotificationService:
|
|||||||
# Send to enabled notification services
|
# Send to enabled notification services
|
||||||
if self._is_discord_enabled():
|
if self._is_discord_enabled():
|
||||||
await self._send_discord_notifications(matching_transactions)
|
await self._send_discord_notifications(matching_transactions)
|
||||||
|
|
||||||
if self._is_telegram_enabled():
|
if self._is_telegram_enabled():
|
||||||
await self._send_telegram_notifications(matching_transactions)
|
await self._send_telegram_notifications(matching_transactions)
|
||||||
|
|
||||||
@@ -40,7 +42,9 @@ class NotificationService:
|
|||||||
await self._send_telegram_test(message)
|
await self._send_telegram_test(message)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error(f"Notification service '{service}' not enabled or not found")
|
logger.error(
|
||||||
|
f"Notification service '{service}' not enabled or not found"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send test notification to {service}: {e}")
|
logger.error(f"Failed to send test notification to {service}: {e}")
|
||||||
@@ -50,54 +54,66 @@ class NotificationService:
|
|||||||
"""Send notification about account expiry"""
|
"""Send notification about account expiry"""
|
||||||
if self._is_discord_enabled():
|
if self._is_discord_enabled():
|
||||||
await self._send_discord_expiry(notification_data)
|
await self._send_discord_expiry(notification_data)
|
||||||
|
|
||||||
if self._is_telegram_enabled():
|
if self._is_telegram_enabled():
|
||||||
await self._send_telegram_expiry(notification_data)
|
await self._send_telegram_expiry(notification_data)
|
||||||
|
|
||||||
def _filter_transactions(self, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def _filter_transactions(
|
||||||
|
self, transactions: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Filter transactions based on notification criteria"""
|
"""Filter transactions based on notification criteria"""
|
||||||
matching = []
|
matching = []
|
||||||
filters_case_insensitive = self.filters_config.get("case-insensitive", {})
|
filters_case_insensitive = self.filters_config.get("case-insensitive", {})
|
||||||
|
|
||||||
for transaction in transactions:
|
for transaction in transactions:
|
||||||
description = transaction.get("description", "").lower()
|
description = transaction.get("description", "").lower()
|
||||||
|
|
||||||
# Check case-insensitive filters
|
# Check case-insensitive filters
|
||||||
for filter_name, filter_value in filters_case_insensitive.items():
|
for filter_name, filter_value in filters_case_insensitive.items():
|
||||||
if filter_value.lower() in description:
|
if filter_value.lower() in description:
|
||||||
matching.append({
|
matching.append(
|
||||||
"name": transaction["description"],
|
{
|
||||||
"value": transaction["transactionValue"],
|
"name": transaction["description"],
|
||||||
"currency": transaction["transactionCurrency"],
|
"value": transaction["transactionValue"],
|
||||||
"date": transaction["transactionDate"],
|
"currency": transaction["transactionCurrency"],
|
||||||
})
|
"date": transaction["transactionDate"],
|
||||||
|
}
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
return matching
|
return matching
|
||||||
|
|
||||||
def _is_discord_enabled(self) -> bool:
|
def _is_discord_enabled(self) -> bool:
|
||||||
"""Check if Discord notifications are enabled"""
|
"""Check if Discord notifications are enabled"""
|
||||||
discord_config = self.notifications_config.get("discord", {})
|
discord_config = self.notifications_config.get("discord", {})
|
||||||
return bool(discord_config.get("webhook") and discord_config.get("enabled", True))
|
return bool(
|
||||||
|
discord_config.get("webhook") and discord_config.get("enabled", True)
|
||||||
|
)
|
||||||
|
|
||||||
def _is_telegram_enabled(self) -> bool:
|
def _is_telegram_enabled(self) -> bool:
|
||||||
"""Check if Telegram notifications are enabled"""
|
"""Check if Telegram notifications are enabled"""
|
||||||
telegram_config = self.notifications_config.get("telegram", {})
|
telegram_config = self.notifications_config.get("telegram", {})
|
||||||
return bool(
|
return bool(
|
||||||
telegram_config.get("token") and
|
telegram_config.get("token")
|
||||||
telegram_config.get("chat_id") and
|
and telegram_config.get("chat_id")
|
||||||
telegram_config.get("enabled", True)
|
and telegram_config.get("enabled", True)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _send_discord_notifications(self, transactions: List[Dict[str, Any]]) -> None:
|
async def _send_discord_notifications(
|
||||||
|
self, transactions: List[Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
"""Send Discord notifications - placeholder implementation"""
|
"""Send Discord notifications - placeholder implementation"""
|
||||||
# Would import and use leggen.notifications.discord
|
# Would import and use leggen.notifications.discord
|
||||||
logger.info(f"Sending {len(transactions)} transaction notifications to Discord")
|
logger.info(f"Sending {len(transactions)} transaction notifications to Discord")
|
||||||
|
|
||||||
async def _send_telegram_notifications(self, transactions: List[Dict[str, Any]]) -> None:
|
async def _send_telegram_notifications(
|
||||||
|
self, transactions: List[Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
"""Send Telegram notifications - placeholder implementation"""
|
"""Send Telegram notifications - placeholder implementation"""
|
||||||
# Would import and use leggen.notifications.telegram
|
# Would import and use leggen.notifications.telegram
|
||||||
logger.info(f"Sending {len(transactions)} transaction notifications to Telegram")
|
logger.info(
|
||||||
|
f"Sending {len(transactions)} transaction notifications to Telegram"
|
||||||
|
)
|
||||||
|
|
||||||
async def _send_discord_test(self, message: str) -> None:
|
async def _send_discord_test(self, message: str) -> None:
|
||||||
"""Send Discord test notification"""
|
"""Send Discord test notification"""
|
||||||
@@ -113,4 +129,4 @@ class NotificationService:
|
|||||||
|
|
||||||
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
|
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
|
||||||
"""Send Telegram expiry notification"""
|
"""Send Telegram expiry notification"""
|
||||||
logger.info(f"Sending Telegram expiry notification: {notification_data}")
|
logger.info(f"Sending Telegram expiry notification: {notification_data}")
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class SyncService:
|
|||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
self._sync_status.is_running = True
|
self._sync_status.is_running = True
|
||||||
self._sync_status.errors = []
|
self._sync_status.errors = []
|
||||||
|
|
||||||
accounts_processed = 0
|
accounts_processed = 0
|
||||||
transactions_added = 0
|
transactions_added = 0
|
||||||
transactions_updated = 0
|
transactions_updated = 0
|
||||||
@@ -39,22 +39,24 @@ class SyncService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Starting sync of all accounts")
|
logger.info("Starting sync of all accounts")
|
||||||
|
|
||||||
# Get all requisitions and accounts
|
# Get all requisitions and accounts
|
||||||
requisitions = await self.gocardless.get_requisitions()
|
requisitions = await self.gocardless.get_requisitions()
|
||||||
all_accounts = set()
|
all_accounts = set()
|
||||||
|
|
||||||
for req in requisitions.get("results", []):
|
for req in requisitions.get("results", []):
|
||||||
all_accounts.update(req.get("accounts", []))
|
all_accounts.update(req.get("accounts", []))
|
||||||
|
|
||||||
self._sync_status.total_accounts = len(all_accounts)
|
self._sync_status.total_accounts = len(all_accounts)
|
||||||
|
|
||||||
# Process each account
|
# Process each account
|
||||||
for account_id in all_accounts:
|
for account_id in all_accounts:
|
||||||
try:
|
try:
|
||||||
# Get account details
|
# Get account details
|
||||||
account_details = await self.gocardless.get_account_details(account_id)
|
account_details = await self.gocardless.get_account_details(
|
||||||
|
account_id
|
||||||
|
)
|
||||||
|
|
||||||
# Get and save balances
|
# Get and save balances
|
||||||
balances = await self.gocardless.get_account_balances(account_id)
|
balances = await self.gocardless.get_account_balances(account_id)
|
||||||
if balances:
|
if balances:
|
||||||
@@ -62,7 +64,9 @@ class SyncService:
|
|||||||
balances_updated += len(balances.get("balances", []))
|
balances_updated += len(balances.get("balances", []))
|
||||||
|
|
||||||
# Get and save transactions
|
# Get and save transactions
|
||||||
transactions = await self.gocardless.get_account_transactions(account_id)
|
transactions = await self.gocardless.get_account_transactions(
|
||||||
|
account_id
|
||||||
|
)
|
||||||
if transactions:
|
if transactions:
|
||||||
processed_transactions = self.database.process_transactions(
|
processed_transactions = self.database.process_transactions(
|
||||||
account_id, account_details, transactions
|
account_id, account_details, transactions
|
||||||
@@ -71,16 +75,18 @@ class SyncService:
|
|||||||
account_id, processed_transactions
|
account_id, processed_transactions
|
||||||
)
|
)
|
||||||
transactions_added += len(new_transactions)
|
transactions_added += len(new_transactions)
|
||||||
|
|
||||||
# Send notifications for new transactions
|
# Send notifications for new transactions
|
||||||
if new_transactions:
|
if new_transactions:
|
||||||
await self.notifications.send_transaction_notifications(new_transactions)
|
await self.notifications.send_transaction_notifications(
|
||||||
|
new_transactions
|
||||||
|
)
|
||||||
|
|
||||||
accounts_processed += 1
|
accounts_processed += 1
|
||||||
self._sync_status.accounts_synced = accounts_processed
|
self._sync_status.accounts_synced = accounts_processed
|
||||||
|
|
||||||
logger.info(f"Synced account {account_id} successfully")
|
logger.info(f"Synced account {account_id} successfully")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to sync account {account_id}: {str(e)}"
|
error_msg = f"Failed to sync account {account_id}: {str(e)}"
|
||||||
errors.append(error_msg)
|
errors.append(error_msg)
|
||||||
@@ -88,9 +94,9 @@ class SyncService:
|
|||||||
|
|
||||||
end_time = datetime.now()
|
end_time = datetime.now()
|
||||||
duration = (end_time - start_time).total_seconds()
|
duration = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
self._sync_status.last_sync = end_time
|
self._sync_status.last_sync = end_time
|
||||||
|
|
||||||
result = SyncResult(
|
result = SyncResult(
|
||||||
success=len(errors) == 0,
|
success=len(errors) == 0,
|
||||||
accounts_processed=accounts_processed,
|
accounts_processed=accounts_processed,
|
||||||
@@ -100,12 +106,14 @@ class SyncService:
|
|||||||
duration_seconds=duration,
|
duration_seconds=duration,
|
||||||
errors=errors,
|
errors=errors,
|
||||||
started_at=start_time,
|
started_at=start_time,
|
||||||
completed_at=end_time
|
completed_at=end_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Sync failed: {str(e)}"
|
error_msg = f"Sync failed: {str(e)}"
|
||||||
errors.append(error_msg)
|
errors.append(error_msg)
|
||||||
@@ -114,7 +122,9 @@ class SyncService:
|
|||||||
finally:
|
finally:
|
||||||
self._sync_status.is_running = False
|
self._sync_status.is_running = False
|
||||||
|
|
||||||
async def sync_specific_accounts(self, account_ids: List[str], force: bool = False) -> SyncResult:
|
async def sync_specific_accounts(
|
||||||
|
self, account_ids: List[str], force: bool = False
|
||||||
|
) -> SyncResult:
|
||||||
"""Sync specific accounts"""
|
"""Sync specific accounts"""
|
||||||
if self._sync_status.is_running and not force:
|
if self._sync_status.is_running and not force:
|
||||||
raise Exception("Sync is already running")
|
raise Exception("Sync is already running")
|
||||||
@@ -123,12 +133,12 @@ class SyncService:
|
|||||||
# For brevity, implementing a simplified version
|
# For brevity, implementing a simplified version
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
self._sync_status.is_running = True
|
self._sync_status.is_running = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Process only specified accounts
|
# Process only specified accounts
|
||||||
# Implementation would be similar to sync_all_accounts
|
# Implementation would be similar to sync_all_accounts
|
||||||
# but filtered to only the specified account_ids
|
# but filtered to only the specified account_ids
|
||||||
|
|
||||||
end_time = datetime.now()
|
end_time = datetime.now()
|
||||||
return SyncResult(
|
return SyncResult(
|
||||||
success=True,
|
success=True,
|
||||||
@@ -139,7 +149,7 @@ class SyncService:
|
|||||||
duration_seconds=(end_time - start_time).total_seconds(),
|
duration_seconds=(end_time - start_time).total_seconds(),
|
||||||
errors=[],
|
errors=[],
|
||||||
started_at=start_time,
|
started_at=start_time,
|
||||||
completed_at=end_time
|
completed_at=end_time,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self._sync_status.is_running = False
|
self._sync_status.is_running = False
|
||||||
|
|||||||
@@ -7,4 +7,4 @@ REQUISITION_STATUS = {
|
|||||||
"GA": "GRANTING_ACCESS",
|
"GA": "GRANTING_ACCESS",
|
||||||
"LN": "LINKED",
|
"LN": "LINKED",
|
||||||
"EX": "EXPIRED",
|
"EX": "EXPIRED",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Pytest configuration and shared fixtures."""
|
"""Pytest configuration and shared fixtures."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
import json
|
import json
|
||||||
@@ -19,34 +20,27 @@ def temp_config_dir():
|
|||||||
yield config_dir
|
yield config_dir
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(temp_config_dir):
|
def mock_config(temp_config_dir):
|
||||||
"""Mock configuration for testing."""
|
"""Mock configuration for testing."""
|
||||||
config_data = {
|
config_data = {
|
||||||
"gocardless": {
|
"gocardless": {
|
||||||
"key": "test-key",
|
"key": "test-key",
|
||||||
"secret": "test-secret",
|
"secret": "test-secret",
|
||||||
"url": "https://bankaccountdata.gocardless.com/api/v2"
|
"url": "https://bankaccountdata.gocardless.com/api/v2",
|
||||||
},
|
},
|
||||||
"database": {
|
"database": {"sqlite": True},
|
||||||
"sqlite": True
|
"scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}},
|
||||||
},
|
|
||||||
"scheduler": {
|
|
||||||
"sync": {
|
|
||||||
"enabled": True,
|
|
||||||
"hour": 3,
|
|
||||||
"minute": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config_file = temp_config_dir / "config.toml"
|
config_file = temp_config_dir / "config.toml"
|
||||||
with open(config_file, "wb") as f:
|
with open(config_file, "wb") as f:
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
tomli_w.dump(config_data, f)
|
tomli_w.dump(config_data, f)
|
||||||
|
|
||||||
# Mock the config path
|
# Mock the config path
|
||||||
with patch.object(Config, 'load_config') as mock_load:
|
with patch.object(Config, "load_config") as mock_load:
|
||||||
mock_load.return_value = config_data
|
mock_load.return_value = config_data
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = config_data
|
config._config = config_data
|
||||||
@@ -57,15 +51,12 @@ def mock_config(temp_config_dir):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_auth_token(temp_config_dir):
|
def mock_auth_token(temp_config_dir):
|
||||||
"""Mock GoCardless authentication token."""
|
"""Mock GoCardless authentication token."""
|
||||||
auth_data = {
|
auth_data = {"access": "mock-access-token", "refresh": "mock-refresh-token"}
|
||||||
"access": "mock-access-token",
|
|
||||||
"refresh": "mock-refresh-token"
|
auth_file = temp_config_dir / "auth.json"
|
||||||
}
|
|
||||||
|
|
||||||
auth_file = temp_config_dir / "auth.json"
|
|
||||||
with open(auth_file, "w") as f:
|
with open(auth_file, "w") as f:
|
||||||
json.dump(auth_data, f)
|
json.dump(auth_data, f)
|
||||||
|
|
||||||
return auth_data
|
return auth_data
|
||||||
|
|
||||||
|
|
||||||
@@ -88,17 +79,17 @@ def sample_bank_data():
|
|||||||
{
|
{
|
||||||
"id": "REVOLUT_REVOLT21",
|
"id": "REVOLUT_REVOLT21",
|
||||||
"name": "Revolut",
|
"name": "Revolut",
|
||||||
"bic": "REVOLT21",
|
"bic": "REVOLT21",
|
||||||
"transaction_total_days": 90,
|
"transaction_total_days": 90,
|
||||||
"countries": ["GB", "LT"]
|
"countries": ["GB", "LT"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "BANCOBPI_BBPIPTPL",
|
"id": "BANCOBPI_BBPIPTPL",
|
||||||
"name": "Banco BPI",
|
"name": "Banco BPI",
|
||||||
"bic": "BBPIPTPL",
|
"bic": "BBPIPTPL",
|
||||||
"transaction_total_days": 90,
|
"transaction_total_days": 90,
|
||||||
"countries": ["PT"]
|
"countries": ["PT"],
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -107,31 +98,28 @@ def sample_account_data():
|
|||||||
"""Sample account data for testing."""
|
"""Sample account data for testing."""
|
||||||
return {
|
return {
|
||||||
"id": "test-account-123",
|
"id": "test-account-123",
|
||||||
"institution_id": "REVOLUT_REVOLT21",
|
"institution_id": "REVOLUT_REVOLT21",
|
||||||
"status": "READY",
|
"status": "READY",
|
||||||
"iban": "LT313250081177977789",
|
"iban": "LT313250081177977789",
|
||||||
"created": "2024-02-13T23:56:00Z",
|
"created": "2024-02-13T23:56:00Z",
|
||||||
"last_accessed": "2025-09-01T09:30:00Z"
|
"last_accessed": "2025-09-01T09:30:00Z",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_transaction_data():
|
def sample_transaction_data():
|
||||||
"""Sample transaction data for testing."""
|
"""Sample transaction data for testing."""
|
||||||
return {
|
return {
|
||||||
"transactions": {
|
"transactions": {
|
||||||
"booked": [
|
"booked": [
|
||||||
{
|
{
|
||||||
"internalTransactionId": "txn-123",
|
"internalTransactionId": "txn-123",
|
||||||
"bookingDate": "2025-09-01",
|
"bookingDate": "2025-09-01",
|
||||||
"valueDate": "2025-09-01",
|
"valueDate": "2025-09-01",
|
||||||
"transactionAmount": {
|
"transactionAmount": {"amount": "-10.50", "currency": "EUR"},
|
||||||
"amount": "-10.50",
|
"remittanceInformationUnstructured": "Coffee Shop Payment",
|
||||||
"currency": "EUR"
|
|
||||||
},
|
|
||||||
"remittanceInformationUnstructured": "Coffee Shop Payment"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"pending": []
|
"pending": [],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for accounts API endpoints."""
|
"""Tests for accounts API endpoints."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import respx
|
import respx
|
||||||
import httpx
|
import httpx
|
||||||
@@ -8,48 +9,47 @@ from unittest.mock import patch
|
|||||||
@pytest.mark.api
|
@pytest.mark.api
|
||||||
class TestAccountsAPI:
|
class TestAccountsAPI:
|
||||||
"""Test account-related API endpoints."""
|
"""Test account-related API endpoints."""
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_all_accounts_success(self, api_client, mock_config, mock_auth_token, sample_account_data):
|
def test_get_all_accounts_success(
|
||||||
|
self, api_client, mock_config, mock_auth_token, sample_account_data
|
||||||
|
):
|
||||||
"""Test successful retrieval of all accounts."""
|
"""Test successful retrieval of all accounts."""
|
||||||
requisitions_data = {
|
requisitions_data = {
|
||||||
"results": [
|
"results": [{"id": "req-123", "accounts": ["test-account-123"]}]
|
||||||
{
|
|
||||||
"id": "req-123",
|
|
||||||
"accounts": ["test-account-123"]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
balances_data = {
|
balances_data = {
|
||||||
"balances": [
|
"balances": [
|
||||||
{
|
{
|
||||||
"balanceAmount": {"amount": "100.50", "currency": "EUR"},
|
"balanceAmount": {"amount": "100.50", "currency": "EUR"},
|
||||||
"balanceType": "interimAvailable",
|
"balanceType": "interimAvailable",
|
||||||
"lastChangeDateTime": "2025-09-01T09:30:00Z"
|
"lastChangeDateTime": "2025-09-01T09:30:00Z",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless API calls
|
# Mock GoCardless API calls
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
||||||
return_value=httpx.Response(200, json=requisitions_data)
|
return_value=httpx.Response(200, json=requisitions_data)
|
||||||
)
|
)
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=sample_account_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=balances_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=balances_data))
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/accounts")
|
response = api_client.get("/api/v1/accounts")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -61,33 +61,37 @@ class TestAccountsAPI:
|
|||||||
assert account["balances"][0]["amount"] == 100.50
|
assert account["balances"][0]["amount"] == 100.50
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_account_details_success(self, api_client, mock_config, mock_auth_token, sample_account_data):
|
def test_get_account_details_success(
|
||||||
|
self, api_client, mock_config, mock_auth_token, sample_account_data
|
||||||
|
):
|
||||||
"""Test successful retrieval of specific account details."""
|
"""Test successful retrieval of specific account details."""
|
||||||
balances_data = {
|
balances_data = {
|
||||||
"balances": [
|
"balances": [
|
||||||
{
|
{
|
||||||
"balanceAmount": {"amount": "250.75", "currency": "EUR"},
|
"balanceAmount": {"amount": "250.75", "currency": "EUR"},
|
||||||
"balanceType": "interimAvailable"
|
"balanceType": "interimAvailable",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless API calls
|
# Mock GoCardless API calls
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=sample_account_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=balances_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=balances_data))
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/accounts/test-account-123")
|
response = api_client.get("/api/v1/accounts/test-account-123")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -97,35 +101,39 @@ class TestAccountsAPI:
|
|||||||
assert len(account["balances"]) == 1
|
assert len(account["balances"]) == 1
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_account_balances_success(self, api_client, mock_config, mock_auth_token):
|
def test_get_account_balances_success(
|
||||||
|
self, api_client, mock_config, mock_auth_token
|
||||||
|
):
|
||||||
"""Test successful retrieval of account balances."""
|
"""Test successful retrieval of account balances."""
|
||||||
balances_data = {
|
balances_data = {
|
||||||
"balances": [
|
"balances": [
|
||||||
{
|
{
|
||||||
"balanceAmount": {"amount": "1000.00", "currency": "EUR"},
|
"balanceAmount": {"amount": "1000.00", "currency": "EUR"},
|
||||||
"balanceType": "interimAvailable",
|
"balanceType": "interimAvailable",
|
||||||
"lastChangeDateTime": "2025-09-01T10:00:00Z"
|
"lastChangeDateTime": "2025-09-01T10:00:00Z",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"balanceAmount": {"amount": "950.00", "currency": "EUR"},
|
"balanceAmount": {"amount": "950.00", "currency": "EUR"},
|
||||||
"balanceType": "expected"
|
"balanceType": "expected",
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless API
|
# Mock GoCardless API
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=balances_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=balances_data))
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/accounts/test-account-123/balances")
|
response = api_client.get("/api/v1/accounts/test-account-123/balances")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -135,29 +143,40 @@ class TestAccountsAPI:
|
|||||||
assert data["data"][0]["balance_type"] == "interimAvailable"
|
assert data["data"][0]["balance_type"] == "interimAvailable"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_account_transactions_success(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data):
|
def test_get_account_transactions_success(
|
||||||
|
self,
|
||||||
|
api_client,
|
||||||
|
mock_config,
|
||||||
|
mock_auth_token,
|
||||||
|
sample_account_data,
|
||||||
|
sample_transaction_data,
|
||||||
|
):
|
||||||
"""Test successful retrieval of account transactions."""
|
"""Test successful retrieval of account transactions."""
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless API calls
|
# Mock GoCardless API calls
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=sample_account_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=sample_transaction_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=sample_transaction_data))
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=true")
|
response = api_client.get(
|
||||||
|
"/api/v1/accounts/test-account-123/transactions?summary_only=true"
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
assert len(data["data"]) == 1
|
assert len(data["data"]) == 1
|
||||||
|
|
||||||
transaction = data["data"][0]
|
transaction = data["data"][0]
|
||||||
assert transaction["internal_transaction_id"] == "txn-123"
|
assert transaction["internal_transaction_id"] == "txn-123"
|
||||||
assert transaction["amount"] == -10.50
|
assert transaction["amount"] == -10.50
|
||||||
@@ -165,29 +184,40 @@ class TestAccountsAPI:
|
|||||||
assert transaction["description"] == "Coffee Shop Payment"
|
assert transaction["description"] == "Coffee Shop Payment"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_account_transactions_full_details(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data):
|
def test_get_account_transactions_full_details(
|
||||||
|
self,
|
||||||
|
api_client,
|
||||||
|
mock_config,
|
||||||
|
mock_auth_token,
|
||||||
|
sample_account_data,
|
||||||
|
sample_transaction_data,
|
||||||
|
):
|
||||||
"""Test retrieval of full transaction details."""
|
"""Test retrieval of full transaction details."""
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless API calls
|
# Mock GoCardless API calls
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=sample_account_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=sample_account_data))
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock(
|
respx.get(
|
||||||
return_value=httpx.Response(200, json=sample_transaction_data)
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/"
|
||||||
)
|
).mock(return_value=httpx.Response(200, json=sample_transaction_data))
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=false")
|
response = api_client.get(
|
||||||
|
"/api/v1/accounts/test-account-123/transactions?summary_only=false"
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
assert len(data["data"]) == 1
|
assert len(data["data"]) == 1
|
||||||
|
|
||||||
transaction = data["data"][0]
|
transaction = data["data"][0]
|
||||||
assert transaction["internal_transaction_id"] == "txn-123"
|
assert transaction["internal_transaction_id"] == "txn-123"
|
||||||
assert transaction["institution_id"] == "REVOLUT_REVOLT21"
|
assert transaction["institution_id"] == "REVOLUT_REVOLT21"
|
||||||
@@ -200,14 +230,18 @@ class TestAccountsAPI:
|
|||||||
with respx.mock:
|
with respx.mock:
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/").mock(
|
respx.get(
|
||||||
|
"https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/"
|
||||||
|
).mock(
|
||||||
return_value=httpx.Response(404, json={"detail": "Account not found"})
|
return_value=httpx.Response(404, json={"detail": "Account not found"})
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/accounts/nonexistent")
|
response = api_client.get("/api/v1/accounts/nonexistent")
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for banks API endpoints."""
|
"""Tests for banks API endpoints."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import respx
|
import respx
|
||||||
import httpx
|
import httpx
|
||||||
@@ -10,23 +11,27 @@ from leggend.services.gocardless_service import GoCardlessService
|
|||||||
@pytest.mark.api
|
@pytest.mark.api
|
||||||
class TestBanksAPI:
|
class TestBanksAPI:
|
||||||
"""Test bank-related API endpoints."""
|
"""Test bank-related API endpoints."""
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_institutions_success(self, api_client, mock_config, mock_auth_token, sample_bank_data):
|
def test_get_institutions_success(
|
||||||
|
self, api_client, mock_config, mock_auth_token, sample_bank_data
|
||||||
|
):
|
||||||
"""Test successful retrieval of bank institutions."""
|
"""Test successful retrieval of bank institutions."""
|
||||||
# Mock GoCardless token creation/refresh
|
# Mock GoCardless token creation/refresh
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless institutions API
|
# Mock GoCardless institutions API
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock(
|
respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock(
|
||||||
return_value=httpx.Response(200, json=sample_bank_data)
|
return_value=httpx.Response(200, json=sample_bank_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/banks/institutions?country=PT")
|
response = api_client.get("/api/v1/banks/institutions?country=PT")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -39,17 +44,19 @@ class TestBanksAPI:
|
|||||||
"""Test institutions endpoint with invalid country code."""
|
"""Test institutions endpoint with invalid country code."""
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock empty institutions response for invalid country
|
# Mock empty institutions response for invalid country
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock(
|
respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock(
|
||||||
return_value=httpx.Response(200, json=[])
|
return_value=httpx.Response(200, json=[])
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/banks/institutions?country=XX")
|
response = api_client.get("/api/v1/banks/institutions?country=XX")
|
||||||
|
|
||||||
# Should still work but return empty or filtered results
|
# Should still work but return empty or filtered results
|
||||||
assert response.status_code in [200, 404]
|
assert response.status_code in [200, 404]
|
||||||
|
|
||||||
@@ -61,27 +68,29 @@ class TestBanksAPI:
|
|||||||
"institution_id": "REVOLUT_REVOLT21",
|
"institution_id": "REVOLUT_REVOLT21",
|
||||||
"status": "CR",
|
"status": "CR",
|
||||||
"created": "2025-09-02T00:00:00Z",
|
"created": "2025-09-02T00:00:00Z",
|
||||||
"link": "https://example.com/auth"
|
"link": "https://example.com/auth",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless requisitions API
|
# Mock GoCardless requisitions API
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
||||||
return_value=httpx.Response(200, json=requisition_data)
|
return_value=httpx.Response(200, json=requisition_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"institution_id": "REVOLUT_REVOLT21",
|
"institution_id": "REVOLUT_REVOLT21",
|
||||||
"redirect_url": "http://localhost:8000/"
|
"redirect_url": "http://localhost:8000/",
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.post("/api/v1/banks/connect", json=request_data)
|
response = api_client.post("/api/v1/banks/connect", json=request_data)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -95,27 +104,29 @@ class TestBanksAPI:
|
|||||||
"results": [
|
"results": [
|
||||||
{
|
{
|
||||||
"id": "req-123",
|
"id": "req-123",
|
||||||
"institution_id": "REVOLUT_REVOLT21",
|
"institution_id": "REVOLUT_REVOLT21",
|
||||||
"status": "LN",
|
"status": "LN",
|
||||||
"created": "2025-09-02T00:00:00Z",
|
"created": "2025-09-02T00:00:00Z",
|
||||||
"accounts": ["acc-123"]
|
"accounts": ["acc-123"],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock GoCardless token creation
|
# Mock GoCardless token creation
|
||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"})
|
return_value=httpx.Response(
|
||||||
|
200, json={"access": "test-token", "refresh": "test-refresh"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock GoCardless requisitions API
|
# Mock GoCardless requisitions API
|
||||||
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock(
|
||||||
return_value=httpx.Response(200, json=requisitions_data)
|
return_value=httpx.Response(200, json=requisitions_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/banks/status")
|
response = api_client.get("/api/v1/banks/status")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -126,12 +137,12 @@ class TestBanksAPI:
|
|||||||
def test_get_supported_countries(self, api_client):
|
def test_get_supported_countries(self, api_client):
|
||||||
"""Test supported countries endpoint."""
|
"""Test supported countries endpoint."""
|
||||||
response = api_client.get("/api/v1/banks/countries")
|
response = api_client.get("/api/v1/banks/countries")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
assert len(data["data"]) > 0
|
assert len(data["data"]) > 0
|
||||||
|
|
||||||
# Check some expected countries
|
# Check some expected countries
|
||||||
country_codes = [country["code"] for country in data["data"]]
|
country_codes = [country["code"] for country in data["data"]]
|
||||||
assert "PT" in country_codes
|
assert "PT" in country_codes
|
||||||
@@ -145,10 +156,10 @@ class TestBanksAPI:
|
|||||||
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock(
|
||||||
return_value=httpx.Response(401, json={"detail": "Invalid credentials"})
|
return_value=httpx.Response(401, json={"detail": "Invalid credentials"})
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('leggend.config.config', mock_config):
|
with patch("leggend.config.config", mock_config):
|
||||||
response = api_client.get("/api/v1/banks/institutions")
|
response = api_client.get("/api/v1/banks/institutions")
|
||||||
|
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "Failed to get institutions" in data["detail"]
|
assert "Failed to get institutions" in data["detail"]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for CLI API client."""
|
"""Tests for CLI API client."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests_mock
|
import requests_mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -13,36 +14,36 @@ class TestLeggendAPIClient:
|
|||||||
def test_health_check_success(self):
|
def test_health_check_success(self):
|
||||||
"""Test successful health check."""
|
"""Test successful health check."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get("http://localhost:8000/health", json={"status": "healthy"})
|
m.get("http://localhost:8000/health", json={"status": "healthy"})
|
||||||
|
|
||||||
result = client.health_check()
|
result = client.health_check()
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
def test_health_check_failure(self):
|
def test_health_check_failure(self):
|
||||||
"""Test health check failure."""
|
"""Test health check failure."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get("http://localhost:8000/health", status_code=500)
|
m.get("http://localhost:8000/health", status_code=500)
|
||||||
|
|
||||||
result = client.health_check()
|
result = client.health_check()
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
def test_get_institutions_success(self, sample_bank_data):
|
def test_get_institutions_success(self, sample_bank_data):
|
||||||
"""Test getting institutions via API client."""
|
"""Test getting institutions via API client."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
api_response = {
|
api_response = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": sample_bank_data,
|
"data": sample_bank_data,
|
||||||
"message": "Found 2 institutions for PT"
|
"message": "Found 2 institutions for PT",
|
||||||
}
|
}
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get("http://localhost:8000/api/v1/banks/institutions", json=api_response)
|
m.get("http://localhost:8000/api/v1/banks/institutions", json=api_response)
|
||||||
|
|
||||||
result = client.get_institutions("PT")
|
result = client.get_institutions("PT")
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result[0]["id"] == "REVOLUT_REVOLT21"
|
assert result[0]["id"] == "REVOLUT_REVOLT21"
|
||||||
@@ -50,16 +51,16 @@ class TestLeggendAPIClient:
|
|||||||
def test_get_accounts_success(self, sample_account_data):
|
def test_get_accounts_success(self, sample_account_data):
|
||||||
"""Test getting accounts via API client."""
|
"""Test getting accounts via API client."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
api_response = {
|
api_response = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": [sample_account_data],
|
"data": [sample_account_data],
|
||||||
"message": "Retrieved 1 accounts"
|
"message": "Retrieved 1 accounts",
|
||||||
}
|
}
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get("http://localhost:8000/api/v1/accounts", json=api_response)
|
m.get("http://localhost:8000/api/v1/accounts", json=api_response)
|
||||||
|
|
||||||
result = client.get_accounts()
|
result = client.get_accounts()
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0]["id"] == "test-account-123"
|
assert result[0]["id"] == "test-account-123"
|
||||||
@@ -67,34 +68,37 @@ class TestLeggendAPIClient:
|
|||||||
def test_trigger_sync_success(self):
|
def test_trigger_sync_success(self):
|
||||||
"""Test triggering sync via API client."""
|
"""Test triggering sync via API client."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
api_response = {
|
api_response = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": {"sync_started": True, "force": False},
|
"data": {"sync_started": True, "force": False},
|
||||||
"message": "Started sync for all accounts"
|
"message": "Started sync for all accounts",
|
||||||
}
|
}
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.post("http://localhost:8000/api/v1/sync", json=api_response)
|
m.post("http://localhost:8000/api/v1/sync", json=api_response)
|
||||||
|
|
||||||
result = client.trigger_sync()
|
result = client.trigger_sync()
|
||||||
assert result["sync_started"] is True
|
assert result["sync_started"] is True
|
||||||
|
|
||||||
def test_connection_error_handling(self):
|
def test_connection_error_handling(self):
|
||||||
"""Test handling of connection errors."""
|
"""Test handling of connection errors."""
|
||||||
client = LeggendAPIClient("http://localhost:9999") # Non-existent service
|
client = LeggendAPIClient("http://localhost:9999") # Non-existent service
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
client.get_accounts()
|
client.get_accounts()
|
||||||
|
|
||||||
def test_http_error_handling(self):
|
def test_http_error_handling(self):
|
||||||
"""Test handling of HTTP errors."""
|
"""Test handling of HTTP errors."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get("http://localhost:8000/api/v1/accounts", status_code=500,
|
m.get(
|
||||||
json={"detail": "Internal server error"})
|
"http://localhost:8000/api/v1/accounts",
|
||||||
|
status_code=500,
|
||||||
|
json={"detail": "Internal server error"},
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
client.get_accounts()
|
client.get_accounts()
|
||||||
|
|
||||||
@@ -102,28 +106,28 @@ class TestLeggendAPIClient:
|
|||||||
"""Test using custom API URL."""
|
"""Test using custom API URL."""
|
||||||
custom_url = "http://custom-host:9000"
|
custom_url = "http://custom-host:9000"
|
||||||
client = LeggendAPIClient(custom_url)
|
client = LeggendAPIClient(custom_url)
|
||||||
|
|
||||||
assert client.base_url == custom_url
|
assert client.base_url == custom_url
|
||||||
|
|
||||||
def test_environment_variable_url(self):
|
def test_environment_variable_url(self):
|
||||||
"""Test using environment variable for API URL."""
|
"""Test using environment variable for API URL."""
|
||||||
with patch.dict('os.environ', {'LEGGEND_API_URL': 'http://env-host:7000'}):
|
with patch.dict("os.environ", {"LEGGEND_API_URL": "http://env-host:7000"}):
|
||||||
client = LeggendAPIClient()
|
client = LeggendAPIClient()
|
||||||
assert client.base_url == "http://env-host:7000"
|
assert client.base_url == "http://env-host:7000"
|
||||||
|
|
||||||
def test_sync_with_options(self):
|
def test_sync_with_options(self):
|
||||||
"""Test sync with various options."""
|
"""Test sync with various options."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
api_response = {
|
api_response = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": {"sync_started": True, "force": True},
|
"data": {"sync_started": True, "force": True},
|
||||||
"message": "Started sync for 2 specific accounts"
|
"message": "Started sync for 2 specific accounts",
|
||||||
}
|
}
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.post("http://localhost:8000/api/v1/sync", json=api_response)
|
m.post("http://localhost:8000/api/v1/sync", json=api_response)
|
||||||
|
|
||||||
result = client.trigger_sync(account_ids=["acc1", "acc2"], force=True)
|
result = client.trigger_sync(account_ids=["acc1", "acc2"], force=True)
|
||||||
assert result["sync_started"] is True
|
assert result["sync_started"] is True
|
||||||
assert result["force"] is True
|
assert result["force"] is True
|
||||||
@@ -131,20 +135,20 @@ class TestLeggendAPIClient:
|
|||||||
def test_get_scheduler_config(self):
|
def test_get_scheduler_config(self):
|
||||||
"""Test getting scheduler configuration."""
|
"""Test getting scheduler configuration."""
|
||||||
client = LeggendAPIClient("http://localhost:8000")
|
client = LeggendAPIClient("http://localhost:8000")
|
||||||
|
|
||||||
api_response = {
|
api_response = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"data": {
|
"data": {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"hour": 3,
|
"hour": 3,
|
||||||
"minute": 0,
|
"minute": 0,
|
||||||
"next_scheduled_sync": "2025-09-03T03:00:00Z"
|
"next_scheduled_sync": "2025-09-03T03:00:00Z",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
with requests_mock.Mocker() as m:
|
with requests_mock.Mocker() as m:
|
||||||
m.get("http://localhost:8000/api/v1/sync/scheduler", json=api_response)
|
m.get("http://localhost:8000/api/v1/sync/scheduler", json=api_response)
|
||||||
|
|
||||||
result = client.get_scheduler_config()
|
result = client.get_scheduler_config()
|
||||||
assert result["enabled"] is True
|
assert result["enabled"] is True
|
||||||
assert result["hour"] == 3
|
assert result["hour"] == 3
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for configuration management."""
|
"""Tests for configuration management."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -23,25 +24,24 @@ class TestConfig:
|
|||||||
"gocardless": {
|
"gocardless": {
|
||||||
"key": "test-key",
|
"key": "test-key",
|
||||||
"secret": "test-secret",
|
"secret": "test-secret",
|
||||||
"url": "https://test.example.com"
|
"url": "https://test.example.com",
|
||||||
},
|
},
|
||||||
"database": {
|
"database": {"sqlite": True},
|
||||||
"sqlite": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config_file = temp_config_dir / "config.toml"
|
config_file = temp_config_dir / "config.toml"
|
||||||
with open(config_file, "wb") as f:
|
with open(config_file, "wb") as f:
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
tomli_w.dump(config_data, f)
|
tomli_w.dump(config_data, f)
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
# Reset singleton state for testing
|
# Reset singleton state for testing
|
||||||
config._config = None
|
config._config = None
|
||||||
config._config_path = None
|
config._config_path = None
|
||||||
|
|
||||||
result = config.load_config(str(config_file))
|
result = config.load_config(str(config_file))
|
||||||
|
|
||||||
assert result == config_data
|
assert result == config_data
|
||||||
assert config.gocardless_config["key"] == "test-key"
|
assert config.gocardless_config["key"] == "test-key"
|
||||||
assert config.database_config["sqlite"] is True
|
assert config.database_config["sqlite"] is True
|
||||||
@@ -50,87 +50,84 @@ class TestConfig:
|
|||||||
"""Test handling of missing configuration file."""
|
"""Test handling of missing configuration file."""
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = None # Reset for test
|
config._config = None # Reset for test
|
||||||
|
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
config.load_config("/nonexistent/config.toml")
|
config.load_config("/nonexistent/config.toml")
|
||||||
|
|
||||||
def test_save_config_success(self, temp_config_dir):
|
def test_save_config_success(self, temp_config_dir):
|
||||||
"""Test successful configuration saving."""
|
"""Test successful configuration saving."""
|
||||||
config_data = {
|
config_data = {"gocardless": {"key": "new-key", "secret": "new-secret"}}
|
||||||
"gocardless": {
|
|
||||||
"key": "new-key",
|
|
||||||
"secret": "new-secret"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config_file = temp_config_dir / "new_config.toml"
|
config_file = temp_config_dir / "new_config.toml"
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = None
|
config._config = None
|
||||||
|
|
||||||
config.save_config(config_data, str(config_file))
|
config.save_config(config_data, str(config_file))
|
||||||
|
|
||||||
# Verify file was created and contains correct data
|
# Verify file was created and contains correct data
|
||||||
assert config_file.exists()
|
assert config_file.exists()
|
||||||
|
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
with open(config_file, "rb") as f:
|
with open(config_file, "rb") as f:
|
||||||
saved_data = tomllib.load(f)
|
saved_data = tomllib.load(f)
|
||||||
|
|
||||||
assert saved_data == config_data
|
assert saved_data == config_data
|
||||||
|
|
||||||
def test_update_config_success(self, temp_config_dir):
|
def test_update_config_success(self, temp_config_dir):
|
||||||
"""Test updating configuration values."""
|
"""Test updating configuration values."""
|
||||||
initial_config = {
|
initial_config = {
|
||||||
"gocardless": {"key": "old-key"},
|
"gocardless": {"key": "old-key"},
|
||||||
"database": {"sqlite": True}
|
"database": {"sqlite": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
config_file = temp_config_dir / "config.toml"
|
config_file = temp_config_dir / "config.toml"
|
||||||
with open(config_file, "wb") as f:
|
with open(config_file, "wb") as f:
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
tomli_w.dump(initial_config, f)
|
tomli_w.dump(initial_config, f)
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = None
|
config._config = None
|
||||||
config.load_config(str(config_file))
|
config.load_config(str(config_file))
|
||||||
|
|
||||||
config.update_config("gocardless", "key", "new-key")
|
config.update_config("gocardless", "key", "new-key")
|
||||||
|
|
||||||
assert config.gocardless_config["key"] == "new-key"
|
assert config.gocardless_config["key"] == "new-key"
|
||||||
|
|
||||||
# Verify it was saved to file
|
# Verify it was saved to file
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
with open(config_file, "rb") as f:
|
with open(config_file, "rb") as f:
|
||||||
saved_data = tomllib.load(f)
|
saved_data = tomllib.load(f)
|
||||||
assert saved_data["gocardless"]["key"] == "new-key"
|
assert saved_data["gocardless"]["key"] == "new-key"
|
||||||
|
|
||||||
def test_update_section_success(self, temp_config_dir):
|
def test_update_section_success(self, temp_config_dir):
|
||||||
"""Test updating entire configuration section."""
|
"""Test updating entire configuration section."""
|
||||||
initial_config = {
|
initial_config = {"database": {"sqlite": True}}
|
||||||
"database": {"sqlite": True}
|
|
||||||
}
|
|
||||||
|
|
||||||
config_file = temp_config_dir / "config.toml"
|
config_file = temp_config_dir / "config.toml"
|
||||||
with open(config_file, "wb") as f:
|
with open(config_file, "wb") as f:
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
tomli_w.dump(initial_config, f)
|
tomli_w.dump(initial_config, f)
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = None
|
config._config = None
|
||||||
config.load_config(str(config_file))
|
config.load_config(str(config_file))
|
||||||
|
|
||||||
new_db_config = {"sqlite": False, "path": "./custom.db"}
|
new_db_config = {"sqlite": False, "path": "./custom.db"}
|
||||||
config.update_section("database", new_db_config)
|
config.update_section("database", new_db_config)
|
||||||
|
|
||||||
assert config.database_config == new_db_config
|
assert config.database_config == new_db_config
|
||||||
|
|
||||||
def test_scheduler_config_defaults(self):
|
def test_scheduler_config_defaults(self):
|
||||||
"""Test scheduler configuration with defaults."""
|
"""Test scheduler configuration with defaults."""
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = {} # Empty config
|
config._config = {} # Empty config
|
||||||
|
|
||||||
scheduler_config = config.scheduler_config
|
scheduler_config = config.scheduler_config
|
||||||
|
|
||||||
assert scheduler_config["sync"]["enabled"] is True
|
assert scheduler_config["sync"]["enabled"] is True
|
||||||
assert scheduler_config["sync"]["hour"] == 3
|
assert scheduler_config["sync"]["hour"] == 3
|
||||||
assert scheduler_config["sync"]["minute"] == 0
|
assert scheduler_config["sync"]["minute"] == 0
|
||||||
@@ -144,16 +141,16 @@ class TestConfig:
|
|||||||
"enabled": False,
|
"enabled": False,
|
||||||
"hour": 6,
|
"hour": 6,
|
||||||
"minute": 30,
|
"minute": 30,
|
||||||
"cron": "0 6 * * 1-5"
|
"cron": "0 6 * * 1-5",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = custom_config
|
config._config = custom_config
|
||||||
|
|
||||||
scheduler_config = config.scheduler_config
|
scheduler_config = config.scheduler_config
|
||||||
|
|
||||||
assert scheduler_config["sync"]["enabled"] is False
|
assert scheduler_config["sync"]["enabled"] is False
|
||||||
assert scheduler_config["sync"]["hour"] == 6
|
assert scheduler_config["sync"]["hour"] == 6
|
||||||
assert scheduler_config["sync"]["minute"] == 30
|
assert scheduler_config["sync"]["minute"] == 30
|
||||||
@@ -161,26 +158,28 @@ class TestConfig:
|
|||||||
|
|
||||||
def test_environment_variable_config_path(self):
|
def test_environment_variable_config_path(self):
|
||||||
"""Test using environment variable for config path."""
|
"""Test using environment variable for config path."""
|
||||||
with patch.dict('os.environ', {'LEGGEN_CONFIG_FILE': '/custom/path/config.toml'}):
|
with patch.dict(
|
||||||
|
"os.environ", {"LEGGEN_CONFIG_FILE": "/custom/path/config.toml"}
|
||||||
|
):
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = None
|
config._config = None
|
||||||
|
|
||||||
with patch('builtins.open', side_effect=FileNotFoundError):
|
with patch("builtins.open", side_effect=FileNotFoundError):
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
config.load_config()
|
config.load_config()
|
||||||
|
|
||||||
def test_notifications_config(self):
|
def test_notifications_config(self):
|
||||||
"""Test notifications configuration access."""
|
"""Test notifications configuration access."""
|
||||||
custom_config = {
|
custom_config = {
|
||||||
"notifications": {
|
"notifications": {
|
||||||
"discord": {"webhook": "https://discord.webhook", "enabled": True},
|
"discord": {"webhook": "https://discord.webhook", "enabled": True},
|
||||||
"telegram": {"token": "bot-token", "chat_id": 123}
|
"telegram": {"token": "bot-token", "chat_id": 123},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = custom_config
|
config._config = custom_config
|
||||||
|
|
||||||
notifications = config.notifications_config
|
notifications = config.notifications_config
|
||||||
assert notifications["discord"]["webhook"] == "https://discord.webhook"
|
assert notifications["discord"]["webhook"] == "https://discord.webhook"
|
||||||
assert notifications["telegram"]["token"] == "bot-token"
|
assert notifications["telegram"]["token"] == "bot-token"
|
||||||
@@ -190,13 +189,13 @@ class TestConfig:
|
|||||||
custom_config = {
|
custom_config = {
|
||||||
"filters": {
|
"filters": {
|
||||||
"case-insensitive": {"salary": "SALARY", "bills": "BILL"},
|
"case-insensitive": {"salary": "SALARY", "bills": "BILL"},
|
||||||
"amount_threshold": 100.0
|
"amount_threshold": 100.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config._config = custom_config
|
config._config = custom_config
|
||||||
|
|
||||||
filters = config.filters_config
|
filters = config.filters_config
|
||||||
assert filters["case-insensitive"]["salary"] == "SALARY"
|
assert filters["case-insensitive"]["salary"] == "SALARY"
|
||||||
assert filters["amount_threshold"] == 100.0
|
assert filters["amount_threshold"] == 100.0
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for background scheduler."""
|
"""Tests for background scheduler."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||||
@@ -16,23 +17,19 @@ class TestBackgroundScheduler:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Mock configuration for scheduler tests."""
|
"""Mock configuration for scheduler tests."""
|
||||||
return {
|
return {"sync": {"enabled": True, "hour": 3, "minute": 0, "cron": None}}
|
||||||
"sync": {
|
|
||||||
"enabled": True,
|
|
||||||
"hour": 3,
|
|
||||||
"minute": 0,
|
|
||||||
"cron": None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def scheduler(self):
|
def scheduler(self):
|
||||||
"""Create scheduler instance for testing."""
|
"""Create scheduler instance for testing."""
|
||||||
with patch('leggend.background.scheduler.SyncService'), \
|
with (
|
||||||
patch('leggend.background.scheduler.config') as mock_config:
|
patch("leggend.background.scheduler.SyncService"),
|
||||||
|
patch("leggend.background.scheduler.config") as mock_config,
|
||||||
mock_config.scheduler_config = {"sync": {"enabled": True, "hour": 3, "minute": 0}}
|
):
|
||||||
|
mock_config.scheduler_config = {
|
||||||
|
"sync": {"enabled": True, "hour": 3, "minute": 0}
|
||||||
|
}
|
||||||
|
|
||||||
# Create scheduler and replace its AsyncIO scheduler with a mock
|
# Create scheduler and replace its AsyncIO scheduler with a mock
|
||||||
scheduler = BackgroundScheduler()
|
scheduler = BackgroundScheduler()
|
||||||
mock_scheduler = MagicMock()
|
mock_scheduler = MagicMock()
|
||||||
@@ -43,35 +40,34 @@ class TestBackgroundScheduler:
|
|||||||
|
|
||||||
def test_scheduler_start_default_config(self, scheduler, mock_config):
|
def test_scheduler_start_default_config(self, scheduler, mock_config):
|
||||||
"""Test starting scheduler with default configuration."""
|
"""Test starting scheduler with default configuration."""
|
||||||
with patch('leggend.config.config') as mock_config_obj:
|
with patch("leggend.config.config") as mock_config_obj:
|
||||||
mock_config_obj.scheduler_config = mock_config
|
mock_config_obj.scheduler_config = mock_config
|
||||||
|
|
||||||
# Mock the job that gets added
|
# Mock the job that gets added
|
||||||
mock_job = MagicMock()
|
mock_job = MagicMock()
|
||||||
mock_job.id = "daily_sync"
|
mock_job.id = "daily_sync"
|
||||||
scheduler.scheduler.get_jobs.return_value = [mock_job]
|
scheduler.scheduler.get_jobs.return_value = [mock_job]
|
||||||
|
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
# Verify scheduler.start() was called
|
# Verify scheduler.start() was called
|
||||||
scheduler.scheduler.start.assert_called_once()
|
scheduler.scheduler.start.assert_called_once()
|
||||||
# Verify add_job was called
|
# Verify add_job was called
|
||||||
scheduler.scheduler.add_job.assert_called_once()
|
scheduler.scheduler.add_job.assert_called_once()
|
||||||
|
|
||||||
def test_scheduler_start_disabled(self, scheduler):
|
def test_scheduler_start_disabled(self, scheduler):
|
||||||
"""Test scheduler behavior when sync is disabled."""
|
"""Test scheduler behavior when sync is disabled."""
|
||||||
disabled_config = {
|
disabled_config = {"sync": {"enabled": False}}
|
||||||
"sync": {"enabled": False}
|
|
||||||
}
|
with (
|
||||||
|
patch.object(scheduler, "scheduler") as mock_scheduler,
|
||||||
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
|
patch("leggend.background.scheduler.config") as mock_config_obj,
|
||||||
patch('leggend.background.scheduler.config') as mock_config_obj:
|
):
|
||||||
|
|
||||||
mock_config_obj.scheduler_config = disabled_config
|
mock_config_obj.scheduler_config = disabled_config
|
||||||
mock_scheduler.running = False
|
mock_scheduler.running = False
|
||||||
|
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
# Verify scheduler.start() was called
|
# Verify scheduler.start() was called
|
||||||
mock_scheduler.start.assert_called_once()
|
mock_scheduler.start.assert_called_once()
|
||||||
# Verify add_job was NOT called for disabled sync
|
# Verify add_job was NOT called for disabled sync
|
||||||
@@ -82,39 +78,35 @@ class TestBackgroundScheduler:
|
|||||||
cron_config = {
|
cron_config = {
|
||||||
"sync": {
|
"sync": {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"cron": "0 6 * * 1-5" # 6 AM on weekdays
|
"cron": "0 6 * * 1-5", # 6 AM on weekdays
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch('leggend.config.config') as mock_config_obj:
|
with patch("leggend.config.config") as mock_config_obj:
|
||||||
mock_config_obj.scheduler_config = cron_config
|
mock_config_obj.scheduler_config = cron_config
|
||||||
|
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
# Verify scheduler.start() and add_job were called
|
# Verify scheduler.start() and add_job were called
|
||||||
scheduler.scheduler.start.assert_called_once()
|
scheduler.scheduler.start.assert_called_once()
|
||||||
scheduler.scheduler.add_job.assert_called_once()
|
scheduler.scheduler.add_job.assert_called_once()
|
||||||
# Verify job was added with correct ID
|
# Verify job was added with correct ID
|
||||||
call_args = scheduler.scheduler.add_job.call_args
|
call_args = scheduler.scheduler.add_job.call_args
|
||||||
assert call_args.kwargs['id'] == 'daily_sync'
|
assert call_args.kwargs["id"] == "daily_sync"
|
||||||
|
|
||||||
def test_scheduler_start_invalid_cron(self, scheduler):
|
def test_scheduler_start_invalid_cron(self, scheduler):
|
||||||
"""Test handling of invalid cron expressions."""
|
"""Test handling of invalid cron expressions."""
|
||||||
invalid_cron_config = {
|
invalid_cron_config = {"sync": {"enabled": True, "cron": "invalid cron"}}
|
||||||
"sync": {
|
|
||||||
"enabled": True,
|
with (
|
||||||
"cron": "invalid cron"
|
patch.object(scheduler, "scheduler") as mock_scheduler,
|
||||||
}
|
patch("leggend.background.scheduler.config") as mock_config_obj,
|
||||||
}
|
):
|
||||||
|
|
||||||
with patch.object(scheduler, 'scheduler') as mock_scheduler, \
|
|
||||||
patch('leggend.background.scheduler.config') as mock_config_obj:
|
|
||||||
|
|
||||||
mock_config_obj.scheduler_config = invalid_cron_config
|
mock_config_obj.scheduler_config = invalid_cron_config
|
||||||
mock_scheduler.running = False
|
mock_scheduler.running = False
|
||||||
|
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
# With invalid cron, scheduler.start() should not be called due to early return
|
# With invalid cron, scheduler.start() should not be called due to early return
|
||||||
# and add_job should not be called
|
# and add_job should not be called
|
||||||
mock_scheduler.start.assert_not_called()
|
mock_scheduler.start.assert_not_called()
|
||||||
@@ -123,24 +115,20 @@ class TestBackgroundScheduler:
|
|||||||
def test_scheduler_shutdown(self, scheduler):
|
def test_scheduler_shutdown(self, scheduler):
|
||||||
"""Test scheduler shutdown."""
|
"""Test scheduler shutdown."""
|
||||||
scheduler.scheduler.running = True
|
scheduler.scheduler.running = True
|
||||||
|
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
|
|
||||||
scheduler.scheduler.shutdown.assert_called_once()
|
scheduler.scheduler.shutdown.assert_called_once()
|
||||||
|
|
||||||
def test_reschedule_sync(self, scheduler, mock_config):
|
def test_reschedule_sync(self, scheduler, mock_config):
|
||||||
"""Test rescheduling sync job."""
|
"""Test rescheduling sync job."""
|
||||||
scheduler.scheduler.running = True
|
scheduler.scheduler.running = True
|
||||||
|
|
||||||
# Reschedule with new config
|
# Reschedule with new config
|
||||||
new_config = {
|
new_config = {"enabled": True, "hour": 6, "minute": 30}
|
||||||
"enabled": True,
|
|
||||||
"hour": 6,
|
|
||||||
"minute": 30
|
|
||||||
}
|
|
||||||
|
|
||||||
scheduler.reschedule_sync(new_config)
|
scheduler.reschedule_sync(new_config)
|
||||||
|
|
||||||
# Verify remove_job and add_job were called
|
# Verify remove_job and add_job were called
|
||||||
scheduler.scheduler.remove_job.assert_called_once_with("daily_sync")
|
scheduler.scheduler.remove_job.assert_called_once_with("daily_sync")
|
||||||
scheduler.scheduler.add_job.assert_called_once()
|
scheduler.scheduler.add_job.assert_called_once()
|
||||||
@@ -148,11 +136,11 @@ class TestBackgroundScheduler:
|
|||||||
def test_reschedule_sync_disable(self, scheduler, mock_config):
|
def test_reschedule_sync_disable(self, scheduler, mock_config):
|
||||||
"""Test disabling sync via reschedule."""
|
"""Test disabling sync via reschedule."""
|
||||||
scheduler.scheduler.running = True
|
scheduler.scheduler.running = True
|
||||||
|
|
||||||
# Disable sync
|
# Disable sync
|
||||||
disabled_config = {"enabled": False}
|
disabled_config = {"enabled": False}
|
||||||
scheduler.reschedule_sync(disabled_config)
|
scheduler.reschedule_sync(disabled_config)
|
||||||
|
|
||||||
# Job should be removed but not re-added
|
# Job should be removed but not re-added
|
||||||
scheduler.scheduler.remove_job.assert_called_once_with("daily_sync")
|
scheduler.scheduler.remove_job.assert_called_once_with("daily_sync")
|
||||||
scheduler.scheduler.add_job.assert_not_called()
|
scheduler.scheduler.add_job.assert_not_called()
|
||||||
@@ -162,9 +150,9 @@ class TestBackgroundScheduler:
|
|||||||
mock_job = MagicMock()
|
mock_job = MagicMock()
|
||||||
mock_job.next_run_time = datetime(2025, 9, 2, 3, 0)
|
mock_job.next_run_time = datetime(2025, 9, 2, 3, 0)
|
||||||
scheduler.scheduler.get_job.return_value = mock_job
|
scheduler.scheduler.get_job.return_value = mock_job
|
||||||
|
|
||||||
next_time = scheduler.get_next_sync_time()
|
next_time = scheduler.get_next_sync_time()
|
||||||
|
|
||||||
assert next_time is not None
|
assert next_time is not None
|
||||||
assert isinstance(next_time, datetime)
|
assert isinstance(next_time, datetime)
|
||||||
scheduler.scheduler.get_job.assert_called_once_with("daily_sync")
|
scheduler.scheduler.get_job.assert_called_once_with("daily_sync")
|
||||||
@@ -172,9 +160,9 @@ class TestBackgroundScheduler:
|
|||||||
def test_get_next_sync_time_no_job(self, scheduler):
|
def test_get_next_sync_time_no_job(self, scheduler):
|
||||||
"""Test getting next sync time when no job is scheduled."""
|
"""Test getting next sync time when no job is scheduled."""
|
||||||
scheduler.scheduler.get_job.return_value = None
|
scheduler.scheduler.get_job.return_value = None
|
||||||
|
|
||||||
next_time = scheduler.get_next_sync_time()
|
next_time = scheduler.get_next_sync_time()
|
||||||
|
|
||||||
assert next_time is None
|
assert next_time is None
|
||||||
scheduler.scheduler.get_job.assert_called_once_with("daily_sync")
|
scheduler.scheduler.get_job.assert_called_once_with("daily_sync")
|
||||||
|
|
||||||
@@ -183,9 +171,9 @@ class TestBackgroundScheduler:
|
|||||||
"""Test successful sync job execution."""
|
"""Test successful sync job execution."""
|
||||||
mock_sync_service = AsyncMock()
|
mock_sync_service = AsyncMock()
|
||||||
scheduler.sync_service = mock_sync_service
|
scheduler.sync_service = mock_sync_service
|
||||||
|
|
||||||
await scheduler._run_sync()
|
await scheduler._run_sync()
|
||||||
|
|
||||||
mock_sync_service.sync_all_accounts.assert_called_once()
|
mock_sync_service.sync_all_accounts.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -194,18 +182,18 @@ class TestBackgroundScheduler:
|
|||||||
mock_sync_service = AsyncMock()
|
mock_sync_service = AsyncMock()
|
||||||
mock_sync_service.sync_all_accounts.side_effect = Exception("Sync failed")
|
mock_sync_service.sync_all_accounts.side_effect = Exception("Sync failed")
|
||||||
scheduler.sync_service = mock_sync_service
|
scheduler.sync_service = mock_sync_service
|
||||||
|
|
||||||
# Should not raise exception, just log error
|
# Should not raise exception, just log error
|
||||||
await scheduler._run_sync()
|
await scheduler._run_sync()
|
||||||
|
|
||||||
mock_sync_service.sync_all_accounts.assert_called_once()
|
mock_sync_service.sync_all_accounts.assert_called_once()
|
||||||
|
|
||||||
def test_scheduler_job_max_instances(self, scheduler, mock_config):
|
def test_scheduler_job_max_instances(self, scheduler, mock_config):
|
||||||
"""Test that sync jobs have max_instances=1."""
|
"""Test that sync jobs have max_instances=1."""
|
||||||
with patch('leggend.config.config') as mock_config_obj:
|
with patch("leggend.config.config") as mock_config_obj:
|
||||||
mock_config_obj.scheduler_config = mock_config
|
mock_config_obj.scheduler_config = mock_config
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|
||||||
# Verify add_job was called with max_instances=1
|
# Verify add_job was called with max_instances=1
|
||||||
call_args = scheduler.scheduler.add_job.call_args
|
call_args = scheduler.scheduler.add_job.call_args
|
||||||
assert call_args.kwargs['max_instances'] == 1
|
assert call_args.kwargs["max_instances"] == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user