From de3da84dffd83e0b232cf76836935a66eb704aee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elisi=C3=A1rio=20Couto?= Date: Wed, 3 Sep 2025 21:11:19 +0100 Subject: [PATCH] chore: Implement code review suggestions and format code. --- Dockerfile.leggend | 32 ---- compose.yml | 15 -- leggen/api_client.py | 77 ++++++--- leggen/commands/balances.py | 12 +- leggen/commands/bank/__init__.py | 36 ---- leggen/commands/bank/add.py | 30 ++-- leggen/commands/status.py | 6 +- leggen/commands/sync.py | 28 ++-- leggen/commands/transactions.py | 54 +++--- leggen/main.py | 6 +- leggend/__init__.py | 0 leggend/api/__init__.py | 0 leggend/api/models/__init__.py | 0 leggend/api/models/accounts.py | 28 ++-- leggend/api/models/banks.py | 16 +- leggend/api/models/common.py | 5 +- leggend/api/models/notifications.py | 8 +- leggend/api/models/sync.py | 20 +-- leggend/api/routes/__init__.py | 0 leggend/api/routes/accounts.py | 149 ++++++++++------- leggend/api/routes/banks.py | 89 +++++----- leggend/api/routes/notifications.py | 121 ++++++++------ leggend/api/routes/sync.py | 133 ++++++++------- leggend/api/routes/transactions.py | 167 +++++++++++-------- leggend/background/__init__.py | 0 leggend/background/scheduler.py | 145 ++++++++++------ leggend/config.py | 22 +-- leggend/main.py | 45 +++-- leggend/services/__init__.py | 0 leggend/services/database_service.py | 60 +++++-- leggend/services/gocardless_service.py | 44 +++-- leggend/services/notification_service.py | 62 ++++--- leggend/services/sync_service.py | 54 +++--- leggend/utils/__init__.py | 0 leggend/utils/gocardless.py | 2 +- tests/__init__.py | 0 tests/conftest.py | 66 +++----- tests/unit/test_api_accounts.py | 204 +++++++++++++---------- tests/unit/test_api_banks.py | 81 +++++---- tests/unit/test_api_client.py | 70 ++++---- tests/unit/test_config.py | 99 ++++++----- tests/unit/test_scheduler.py | 124 +++++++------- 42 files changed, 1144 insertions(+), 966 deletions(-) delete mode 100644 Dockerfile.leggend delete mode 100644 leggen/commands/bank/__init__.py delete mode 100644 leggend/__init__.py delete mode 100644 leggend/api/__init__.py delete mode 100644 leggend/api/models/__init__.py delete mode 100644 leggend/api/routes/__init__.py delete mode 100644 leggend/background/__init__.py delete mode 100644 leggend/services/__init__.py delete mode 100644 leggend/utils/__init__.py delete mode 100644 tests/__init__.py diff --git a/Dockerfile.leggend b/Dockerfile.leggend deleted file mode 100644 index 231e765..0000000 --- a/Dockerfile.leggend +++ /dev/null @@ -1,32 +0,0 @@ -FROM python:3.13-alpine AS builder -COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ - -WORKDIR /app - -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=uv.lock,target=uv.lock \ - --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - uv sync --frozen --no-install-project --no-editable - -COPY . /app - -RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync --frozen --no-editable --no-group dev - -FROM python:3.13-alpine - -LABEL org.opencontainers.image.source="https://github.com/elisiariocouto/leggen" -LABEL org.opencontainers.image.authors="Elisiário Couto " -LABEL org.opencontainers.image.licenses="MIT" -LABEL org.opencontainers.image.title="leggend" -LABEL org.opencontainers.image.description="Leggen API service" -LABEL org.opencontainers.image.url="https://github.com/elisiariocouto/leggen" - -WORKDIR /app -ENV PATH="/app/.venv/bin:$PATH" - -COPY --from=builder --chown=app:app /app/.venv /app/.venv - -EXPOSE 8000 - -ENTRYPOINT ["/app/.venv/bin/leggend"] \ No newline at end of file diff --git a/compose.yml b/compose.yml index 2f19f58..956dbaf 100644 --- a/compose.yml +++ b/compose.yml @@ -3,7 +3,6 @@ services: leggend: build: context: . - dockerfile: Dockerfile.leggend restart: "unless-stopped" ports: - "127.0.0.1:8000:8000" @@ -18,20 +17,6 @@ services: timeout: 10s retries: 3 - # CLI for one-off operations (uses leggend API) - leggen: - image: elisiariocouto/leggen:latest - command: sync --wait - restart: "no" - volumes: - - "./leggen:/root/.config/leggen" - - "./db:/app" - environment: - - LEGGEND_API_URL=http://leggend:8000 - depends_on: - leggend: - condition: service_healthy - nocodb: image: nocodb/nocodb:latest restart: "unless-stopped" diff --git a/leggen/api_client.py b/leggen/api_client.py index eb25d7c..5496b52 100644 --- a/leggen/api_client.py +++ b/leggen/api_client.py @@ -8,19 +8,20 @@ from leggen.utils.text import error class LeggendAPIClient: """Client for communicating with the leggend FastAPI service""" - + def __init__(self, base_url: Optional[str] = None): - self.base_url = base_url or os.environ.get("LEGGEND_API_URL", "http://localhost:8000") + self.base_url = base_url or os.environ.get( + "LEGGEND_API_URL", "http://localhost:8000" + ) self.session = requests.Session() - self.session.headers.update({ - "Content-Type": "application/json", - "Accept": "application/json" - }) + self.session.headers.update( + {"Content-Type": "application/json", "Accept": "application/json"} + ) def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]: """Make HTTP request to the API""" url = urljoin(self.base_url, endpoint) - + try: response = self.session.request(method, url, **kwargs) response.raise_for_status() @@ -53,15 +54,19 @@ class LeggendAPIClient: # Bank endpoints def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]: """Get bank institutions for a country""" - response = self._make_request("GET", "/api/v1/banks/institutions", params={"country": country}) + response = self._make_request( + "GET", "/api/v1/banks/institutions", params={"country": country} + ) return response.get("data", []) - def connect_to_bank(self, institution_id: str, redirect_url: str = "http://localhost:8000/") -> Dict[str, Any]: + def connect_to_bank( + self, institution_id: str, redirect_url: str = "http://localhost:8000/" + ) -> Dict[str, Any]: """Connect to a bank""" response = self._make_request( - "POST", + "POST", "/api/v1/banks/connect", - json={"institution_id": institution_id, "redirect_url": redirect_url} + json={"institution_id": institution_id, "redirect_url": redirect_url}, ) return response.get("data", {}) @@ -91,31 +96,39 @@ class LeggendAPIClient: response = self._make_request("GET", f"/api/v1/accounts/{account_id}/balances") return response.get("data", []) - def get_account_transactions(self, account_id: str, limit: int = 100, summary_only: bool = False) -> List[Dict[str, Any]]: + def get_account_transactions( + self, account_id: str, limit: int = 100, summary_only: bool = False + ) -> List[Dict[str, Any]]: """Get account transactions""" response = self._make_request( - "GET", + "GET", f"/api/v1/accounts/{account_id}/transactions", - params={"limit": limit, "summary_only": summary_only} + params={"limit": limit, "summary_only": summary_only}, ) return response.get("data", []) - # Transaction endpoints - def get_all_transactions(self, limit: int = 100, summary_only: bool = True, **filters) -> List[Dict[str, Any]]: + # Transaction endpoints + def get_all_transactions( + self, limit: int = 100, summary_only: bool = True, **filters + ) -> List[Dict[str, Any]]: """Get all transactions with optional filters""" params = {"limit": limit, "summary_only": summary_only} params.update(filters) - + response = self._make_request("GET", "/api/v1/transactions", params=params) return response.get("data", []) - def get_transaction_stats(self, days: int = 30, account_id: Optional[str] = None) -> Dict[str, Any]: + def get_transaction_stats( + self, days: int = 30, account_id: Optional[str] = None + ) -> Dict[str, Any]: """Get transaction statistics""" params = {"days": days} if account_id: params["account_id"] = account_id - - response = self._make_request("GET", "/api/v1/transactions/stats", params=params) + + response = self._make_request( + "GET", "/api/v1/transactions/stats", params=params + ) return response.get("data", {}) # Sync endpoints @@ -124,21 +137,25 @@ class LeggendAPIClient: response = self._make_request("GET", "/api/v1/sync/status") return response.get("data", {}) - def trigger_sync(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]: + def trigger_sync( + self, account_ids: Optional[List[str]] = None, force: bool = False + ) -> Dict[str, Any]: """Trigger a sync""" data = {"force": force} if account_ids: data["account_ids"] = account_ids - + response = self._make_request("POST", "/api/v1/sync", json=data) return response.get("data", {}) - def sync_now(self, account_ids: Optional[List[str]] = None, force: bool = False) -> Dict[str, Any]: + def sync_now( + self, account_ids: Optional[List[str]] = None, force: bool = False + ) -> Dict[str, Any]: """Run sync synchronously""" data = {"force": force} if account_ids: data["account_ids"] = account_ids - + response = self._make_request("POST", "/api/v1/sync/now", json=data) return response.get("data", {}) @@ -147,11 +164,17 @@ class LeggendAPIClient: response = self._make_request("GET", "/api/v1/sync/scheduler") return response.get("data", {}) - def update_scheduler_config(self, enabled: bool = True, hour: int = 3, minute: int = 0, cron: Optional[str] = None) -> Dict[str, Any]: + def update_scheduler_config( + self, + enabled: bool = True, + hour: int = 3, + minute: int = 0, + cron: Optional[str] = None, + ) -> Dict[str, Any]: """Update scheduler configuration""" data = {"enabled": enabled, "hour": hour, "minute": minute} if cron: data["cron"] = cron - + response = self._make_request("PUT", "/api/v1/sync/scheduler", json=data) - return response.get("data", {}) \ No newline at end of file + return response.get("data", {}) diff --git a/leggen/commands/balances.py b/leggen/commands/balances.py index df460b0..202c304 100644 --- a/leggen/commands/balances.py +++ b/leggen/commands/balances.py @@ -12,10 +12,12 @@ def balances(ctx: click.Context): List balances of all connected accounts """ api_client = LeggendAPIClient(ctx.obj.get("api_url")) - + # Check if leggend service is available if not api_client.health_check(): - click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") + click.echo( + "Error: Cannot connect to leggend service. Please ensure it's running." + ) return accounts = api_client.get_accounts() @@ -24,11 +26,7 @@ def balances(ctx: click.Context): for account in accounts: for balance in account.get("balances", []): amount = round(float(balance["amount"]), 2) - symbol = ( - "€" - if balance["currency"] == "EUR" - else f" {balance['currency']}" - ) + symbol = "€" if balance["currency"] == "EUR" else f" {balance['currency']}" amount_str = f"{amount}{symbol}" date = ( datefmt(balance.get("last_change_date")) diff --git a/leggen/commands/bank/__init__.py b/leggen/commands/bank/__init__.py deleted file mode 100644 index 94935f3..0000000 --- a/leggen/commands/bank/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -import click - -from leggen.main import cli - -cmd_folder = os.path.abspath(os.path.dirname(__file__)) - - -class BankGroup(click.Group): - def list_commands(self, ctx): - rv = [] - for filename in os.listdir(cmd_folder): - if filename.endswith(".py") and not filename.startswith("__init__"): - if filename == "list_banks.py": - rv.append("list") - else: - rv.append(filename[:-3]) - rv.sort() - return rv - - def get_command(self, ctx, name): - try: - if name == "list": - name = "list_banks" - mod = __import__(f"leggen.commands.bank.{name}", None, None, [name]) - except ImportError: - return - return getattr(mod, name) - - -@cli.group(cls=BankGroup) -@click.pass_context -def bank(ctx): - """Manage banks connections""" - return diff --git a/leggen/commands/bank/add.py b/leggen/commands/bank/add.py index daa018d..8979aa7 100644 --- a/leggen/commands/bank/add.py +++ b/leggen/commands/bank/add.py @@ -13,30 +13,32 @@ def add(ctx): Connect to a bank """ api_client = LeggendAPIClient(ctx.obj.get("api_url")) - + # Check if leggend service is available if not api_client.health_check(): - click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") + click.echo( + "Error: Cannot connect to leggend service. Please ensure it's running." + ) return try: # Get supported countries countries = api_client.get_supported_countries() country_codes = [c["code"] for c in countries] - + country = click.prompt( "Bank Country", type=click.Choice(country_codes, case_sensitive=True), default="PT", ) - + info(f"Getting bank list for country: {country}") banks = api_client.get_institutions(country) - + if not banks: warning(f"No banks available for country {country}") return - + filtered_banks = [ { "id": bank["id"], @@ -46,14 +48,14 @@ def add(ctx): for bank in banks ] print_table(filtered_banks) - + allowed_ids = [str(bank["id"]) for bank in banks] bank_id = click.prompt("Bank ID", type=click.Choice(allowed_ids)) - + # Show bank details selected_bank = next(bank for bank in banks if bank["id"] == bank_id) info(f"Selected bank: {selected_bank['name']}") - + click.confirm("Do you agree to connect to this bank?", abort=True) info(f"Connecting to bank with ID: {bank_id}") @@ -65,11 +67,15 @@ def add(ctx): save_file(f"req_{result['id']}.json", result) success("Bank connection request created successfully!") - warning(f"Please open the following URL in your browser to complete the authorization:") + warning( + f"Please open the following URL in your browser to complete the authorization:" + ) click.echo(f"\n{result['link']}\n") - + info(f"Requisition ID: {result['id']}") - info("After completing the authorization, you can check the connection status with 'leggen status'") + info( + "After completing the authorization, you can check the connection status with 'leggen status'" + ) except Exception as e: click.echo(f"Error: Failed to connect to bank: {str(e)}") diff --git a/leggen/commands/status.py b/leggen/commands/status.py index da9e868..be41fab 100644 --- a/leggen/commands/status.py +++ b/leggen/commands/status.py @@ -12,10 +12,12 @@ def status(ctx: click.Context): List all connected banks and their status """ api_client = LeggendAPIClient(ctx.obj.get("api_url")) - + # Check if leggend service is available if not api_client.health_check(): - click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") + click.echo( + "Error: Cannot connect to leggend service. Please ensure it's running." + ) return # Get bank connection status diff --git a/leggen/commands/sync.py b/leggen/commands/sync.py index 34ea686..4becaf8 100644 --- a/leggen/commands/sync.py +++ b/leggen/commands/sync.py @@ -6,15 +6,15 @@ from leggen.utils.text import error, info, success @cli.command() -@click.option('--wait', is_flag=True, help='Wait for sync to complete (synchronous)') -@click.option('--force', is_flag=True, help='Force sync even if already running') +@click.option("--wait", is_flag=True, help="Wait for sync to complete (synchronous)") +@click.option("--force", is_flag=True, help="Force sync even if already running") @click.pass_context def sync(ctx: click.Context, wait: bool, force: bool): """ Sync all transactions with database """ api_client = LeggendAPIClient(ctx.obj.get("api_url")) - + # Check if leggend service is available if not api_client.health_check(): error("Cannot connect to leggend service. Please ensure it's running.") @@ -25,35 +25,37 @@ def sync(ctx: click.Context, wait: bool, force: bool): # Run sync synchronously and wait for completion info("Starting synchronous sync...") result = api_client.sync_now(force=force) - + if result.get("success"): success(f"Sync completed successfully!") info(f"Accounts processed: {result.get('accounts_processed', 0)}") info(f"Transactions added: {result.get('transactions_added', 0)}") info(f"Balances updated: {result.get('balances_updated', 0)}") - if result.get('duration_seconds'): + if result.get("duration_seconds"): info(f"Duration: {result['duration_seconds']:.2f} seconds") - - if result.get('errors'): + + if result.get("errors"): error(f"Errors encountered: {len(result['errors'])}") - for err in result['errors']: + for err in result["errors"]: error(f" - {err}") else: error("Sync failed") - if result.get('errors'): - for err in result['errors']: + if result.get("errors"): + for err in result["errors"]: error(f" - {err}") else: # Trigger async sync info("Starting background sync...") result = api_client.trigger_sync(force=force) - + if result.get("sync_started"): success("Sync started successfully in the background") - info("Use 'leggen sync --wait' to run synchronously or check status with API") + info( + "Use 'leggen sync --wait' to run synchronously or check status with API" + ) else: error("Failed to start sync") - + except Exception as e: error(f"Sync failed: {str(e)}") return diff --git a/leggen/commands/transactions.py b/leggen/commands/transactions.py index 8924047..c7b5de3 100644 --- a/leggen/commands/transactions.py +++ b/leggen/commands/transactions.py @@ -7,7 +7,9 @@ from leggen.utils.text import datefmt, info, print_table @cli.command() @click.option("-a", "--account", type=str, help="Account ID") -@click.option("-l", "--limit", type=int, default=50, help="Number of transactions to show") +@click.option( + "-l", "--limit", type=int, default=50, help="Number of transactions to show" +) @click.option("--full", is_flag=True, help="Show full transaction details") @click.pass_context def transactions(ctx: click.Context, account: str, limit: int, full: bool): @@ -19,10 +21,12 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool): If the --account option is used, it will only list transactions for that account. """ api_client = LeggendAPIClient(ctx.obj.get("api_url")) - + # Check if leggend service is available if not api_client.health_check(): - click.echo("Error: Cannot connect to leggend service. Please ensure it's running.") + click.echo( + "Error: Cannot connect to leggend service. Please ensure it's running." + ) return try: @@ -32,16 +36,14 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool): transactions_data = api_client.get_account_transactions( account, limit=limit, summary_only=not full ) - + info(f"Bank: {account_details['institution_id']}") info(f"IBAN: {account_details.get('iban', 'N/A')}") - + else: # Get all transactions transactions_data = api_client.get_all_transactions( - limit=limit, - summary_only=not full, - account_id=account + limit=limit, summary_only=not full, account_id=account ) # Format transactions for display @@ -49,24 +51,32 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool): # Full transaction details formatted_transactions = [] for txn in transactions_data: - formatted_transactions.append({ - "ID": txn["internal_transaction_id"][:12] + "...", - "Date": datefmt(txn["transaction_date"]), - "Description": txn["description"][:50] + "..." if len(txn["description"]) > 50 else txn["description"], - "Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}", - "Status": txn["transaction_status"].upper(), - "Account": txn["account_id"][:8] + "...", - }) + formatted_transactions.append( + { + "ID": txn["internal_transaction_id"][:12] + "...", + "Date": datefmt(txn["transaction_date"]), + "Description": txn["description"][:50] + "..." + if len(txn["description"]) > 50 + else txn["description"], + "Amount": f"{txn['transaction_value']:.2f} {txn['transaction_currency']}", + "Status": txn["transaction_status"].upper(), + "Account": txn["account_id"][:8] + "...", + } + ) else: # Summary view formatted_transactions = [] for txn in transactions_data: - formatted_transactions.append({ - "Date": datefmt(txn["date"]), - "Description": txn["description"][:60] + "..." if len(txn["description"]) > 60 else txn["description"], - "Amount": f"{txn['amount']:.2f} {txn['currency']}", - "Status": txn["status"].upper(), - }) + formatted_transactions.append( + { + "Date": datefmt(txn["date"]), + "Description": txn["description"][:60] + "..." + if len(txn["description"]) > 60 + else txn["description"], + "Amount": f"{txn['amount']:.2f} {txn['currency']}", + "Status": txn["status"].upper(), + } + ) if formatted_transactions: print_table(formatted_transactions) diff --git a/leggen/main.py b/leggen/main.py index 6e674b1..8859ed4 100644 --- a/leggen/main.py +++ b/leggen/main.py @@ -90,10 +90,10 @@ class Group(click.Group): @click.option( "--api-url", type=str, - default=None, + default="http://localhost:8000", envvar="LEGGEND_API_URL", show_envvar=True, - help="URL of the leggend API service (default: http://localhost:8000)", + help="URL of the leggend API service", ) @click.group( cls=Group, @@ -113,7 +113,7 @@ def cli(ctx: click.Context, api_url: str): # Store API URL in context for commands to use if api_url: ctx.obj["api_url"] = api_url - + # For backwards compatibility, still support direct GoCardless calls # This will be used as fallback if leggend service is not available try: diff --git a/leggend/__init__.py b/leggend/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/api/__init__.py b/leggend/api/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/api/models/__init__.py b/leggend/api/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/api/models/accounts.py b/leggend/api/models/accounts.py index 78a37f1..6678872 100644 --- a/leggend/api/models/accounts.py +++ b/leggend/api/models/accounts.py @@ -6,19 +6,19 @@ from pydantic import BaseModel class AccountBalance(BaseModel): """Account balance model""" + amount: float currency: str balance_type: str last_change_date: Optional[datetime] = None - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() if v else None - } + json_encoders = {datetime: lambda v: v.isoformat() if v else None} class AccountDetails(BaseModel): """Account details model""" + id: str institution_id: str status: str @@ -28,15 +28,14 @@ class AccountDetails(BaseModel): created: datetime last_accessed: Optional[datetime] = None balances: List[AccountBalance] = [] - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() if v else None - } + json_encoders = {datetime: lambda v: v.isoformat() if v else None} class Transaction(BaseModel): """Transaction model""" + internal_transaction_id: str institution_id: str iban: Optional[str] = None @@ -47,15 +46,14 @@ class Transaction(BaseModel): transaction_currency: str transaction_status: str # "booked" or "pending" raw_transaction: Dict[str, Any] - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() - } + json_encoders = {datetime: lambda v: v.isoformat()} class TransactionSummary(BaseModel): """Transaction summary for lists""" + internal_transaction_id: str date: datetime description: str @@ -63,8 +61,6 @@ class TransactionSummary(BaseModel): currency: str status: str account_id: str - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() - } \ No newline at end of file + json_encoders = {datetime: lambda v: v.isoformat()} diff --git a/leggend/api/models/banks.py b/leggend/api/models/banks.py index 3c4b74a..ed02519 100644 --- a/leggend/api/models/banks.py +++ b/leggend/api/models/banks.py @@ -6,6 +6,7 @@ from pydantic import BaseModel class BankInstitution(BaseModel): """Bank institution model""" + id: str name: str bic: Optional[str] = None @@ -16,12 +17,14 @@ class BankInstitution(BaseModel): class BankConnectionRequest(BaseModel): """Request to connect to a bank""" + institution_id: str redirect_url: Optional[str] = "http://localhost:8000/" class BankRequisition(BaseModel): """Bank requisition/connection model""" + id: str institution_id: str status: str @@ -29,15 +32,14 @@ class BankRequisition(BaseModel): created: datetime link: str accounts: List[str] = [] - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() - } + json_encoders = {datetime: lambda v: v.isoformat()} class BankConnectionStatus(BaseModel): """Bank connection status response""" + bank_id: str bank_name: str status: str @@ -45,8 +47,6 @@ class BankConnectionStatus(BaseModel): created_at: datetime requisition_id: str accounts_count: int - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() - } \ No newline at end of file + json_encoders = {datetime: lambda v: v.isoformat()} diff --git a/leggend/api/models/common.py b/leggend/api/models/common.py index a6b6f5b..626368e 100644 --- a/leggend/api/models/common.py +++ b/leggend/api/models/common.py @@ -6,6 +6,7 @@ from pydantic import BaseModel class APIResponse(BaseModel): """Base API response model""" + success: bool = True message: Optional[str] = None data: Optional[Any] = None @@ -13,6 +14,7 @@ class APIResponse(BaseModel): class ErrorResponse(BaseModel): """Error response model""" + success: bool = False message: str error_code: Optional[str] = None @@ -21,7 +23,8 @@ class ErrorResponse(BaseModel): class PaginatedResponse(BaseModel): """Paginated response model""" + success: bool = True data: list pagination: Dict[str, Any] - message: Optional[str] = None \ No newline at end of file + message: Optional[str] = None diff --git a/leggend/api/models/notifications.py b/leggend/api/models/notifications.py index 34fad35..c966886 100644 --- a/leggend/api/models/notifications.py +++ b/leggend/api/models/notifications.py @@ -5,12 +5,14 @@ from pydantic import BaseModel class DiscordConfig(BaseModel): """Discord notification configuration""" + webhook: str enabled: bool = True class TelegramConfig(BaseModel): """Telegram notification configuration""" + token: str chat_id: int enabled: bool = True @@ -18,6 +20,7 @@ class TelegramConfig(BaseModel): class NotificationFilters(BaseModel): """Notification filters configuration""" + case_insensitive: Dict[str, str] = {} case_sensitive: Optional[Dict[str, str]] = None amount_threshold: Optional[float] = None @@ -26,6 +29,7 @@ class NotificationFilters(BaseModel): class NotificationSettings(BaseModel): """Complete notification settings""" + discord: Optional[DiscordConfig] = None telegram: Optional[TelegramConfig] = None filters: NotificationFilters = NotificationFilters() @@ -33,15 +37,17 @@ class NotificationSettings(BaseModel): class NotificationTest(BaseModel): """Test notification request""" + service: str # "discord" or "telegram" message: str = "Test notification from Leggen" class NotificationHistory(BaseModel): """Notification history entry""" + id: str service: str message: str status: str # "sent", "failed" sent_at: str - error: Optional[str] = None \ No newline at end of file + error: Optional[str] = None diff --git a/leggend/api/models/sync.py b/leggend/api/models/sync.py index f1753c7..44bd9c7 100644 --- a/leggend/api/models/sync.py +++ b/leggend/api/models/sync.py @@ -6,12 +6,14 @@ from pydantic import BaseModel class SyncRequest(BaseModel): """Request to trigger a sync""" + account_ids: Optional[list[str]] = None # If None, sync all accounts force: bool = False # Force sync even if recently synced class SyncStatus(BaseModel): """Sync operation status""" + is_running: bool last_sync: Optional[datetime] = None next_sync: Optional[datetime] = None @@ -19,15 +21,14 @@ class SyncStatus(BaseModel): total_accounts: int = 0 transactions_added: int = 0 errors: list[str] = [] - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() if v else None - } + json_encoders = {datetime: lambda v: v.isoformat() if v else None} class SyncResult(BaseModel): """Result of a sync operation""" + success: bool accounts_processed: int transactions_added: int @@ -37,19 +38,18 @@ class SyncResult(BaseModel): errors: list[str] = [] started_at: datetime completed_at: datetime - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() - } + json_encoders = {datetime: lambda v: v.isoformat()} class SchedulerConfig(BaseModel): """Scheduler configuration model""" + enabled: bool = True hour: Optional[int] = 3 minute: Optional[int] = 0 cron: Optional[str] = None # Custom cron expression - + class Config: - extra = "forbid" \ No newline at end of file + extra = "forbid" diff --git a/leggend/api/routes/__init__.py b/leggend/api/routes/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/api/routes/accounts.py b/leggend/api/routes/accounts.py index d919f63..57549f2 100644 --- a/leggend/api/routes/accounts.py +++ b/leggend/api/routes/accounts.py @@ -3,7 +3,12 @@ from fastapi import APIRouter, HTTPException, Query from loguru import logger from leggend.api.models.common import APIResponse -from leggend.api.models.accounts import AccountDetails, AccountBalance, Transaction, TransactionSummary +from leggend.api.models.accounts import ( + AccountDetails, + AccountBalance, + Transaction, + TransactionSummary, +) from leggend.services.gocardless_service import GoCardlessService from leggend.services.database_service import DatabaseService @@ -17,50 +22,56 @@ async def get_all_accounts() -> APIResponse: """Get all connected accounts""" try: requisitions_data = await gocardless_service.get_requisitions() - + all_accounts = set() for req in requisitions_data.get("results", []): all_accounts.update(req.get("accounts", [])) - + accounts = [] for account_id in all_accounts: try: - account_details = await gocardless_service.get_account_details(account_id) - balances_data = await gocardless_service.get_account_balances(account_id) - + account_details = await gocardless_service.get_account_details( + account_id + ) + balances_data = await gocardless_service.get_account_balances( + account_id + ) + # Process balances balances = [] for balance in balances_data.get("balances", []): balance_amount = balance["balanceAmount"] - balances.append(AccountBalance( - amount=float(balance_amount["amount"]), - currency=balance_amount["currency"], - balance_type=balance["balanceType"], - last_change_date=balance.get("lastChangeDateTime") - )) - - accounts.append(AccountDetails( - id=account_details["id"], - institution_id=account_details["institution_id"], - status=account_details["status"], - iban=account_details.get("iban"), - name=account_details.get("name"), - currency=account_details.get("currency"), - created=account_details["created"], - last_accessed=account_details.get("last_accessed"), - balances=balances - )) - + balances.append( + AccountBalance( + amount=float(balance_amount["amount"]), + currency=balance_amount["currency"], + balance_type=balance["balanceType"], + last_change_date=balance.get("lastChangeDateTime"), + ) + ) + + accounts.append( + AccountDetails( + id=account_details["id"], + institution_id=account_details["institution_id"], + status=account_details["status"], + iban=account_details.get("iban"), + name=account_details.get("name"), + currency=account_details.get("currency"), + created=account_details["created"], + last_accessed=account_details.get("last_accessed"), + balances=balances, + ) + ) + except Exception as e: logger.error(f"Failed to get details for account {account_id}: {e}") continue - + return APIResponse( - success=True, - data=accounts, - message=f"Retrieved {len(accounts)} accounts" + success=True, data=accounts, message=f"Retrieved {len(accounts)} accounts" ) - + except Exception as e: logger.error(f"Failed to get accounts: {e}") raise HTTPException(status_code=500, detail=f"Failed to get accounts: {str(e)}") @@ -72,18 +83,20 @@ async def get_account_details(account_id: str) -> APIResponse: try: account_details = await gocardless_service.get_account_details(account_id) balances_data = await gocardless_service.get_account_balances(account_id) - + # Process balances balances = [] for balance in balances_data.get("balances", []): balance_amount = balance["balanceAmount"] - balances.append(AccountBalance( - amount=float(balance_amount["amount"]), - currency=balance_amount["currency"], - balance_type=balance["balanceType"], - last_change_date=balance.get("lastChangeDateTime") - )) - + balances.append( + AccountBalance( + amount=float(balance_amount["amount"]), + currency=balance_amount["currency"], + balance_type=balance["balanceType"], + last_change_date=balance.get("lastChangeDateTime"), + ) + ) + account = AccountDetails( id=account_details["id"], institution_id=account_details["institution_id"], @@ -93,15 +106,15 @@ async def get_account_details(account_id: str) -> APIResponse: currency=account_details.get("currency"), created=account_details["created"], last_accessed=account_details.get("last_accessed"), - balances=balances + balances=balances, ) - + return APIResponse( success=True, data=account, - message=f"Account details retrieved for {account_id}" + message=f"Account details retrieved for {account_id}", ) - + except Exception as e: logger.error(f"Failed to get account details for {account_id}: {e}") raise HTTPException(status_code=404, detail=f"Account not found: {str(e)}") @@ -112,23 +125,25 @@ async def get_account_balances(account_id: str) -> APIResponse: """Get balances for a specific account""" try: balances_data = await gocardless_service.get_account_balances(account_id) - + balances = [] for balance in balances_data.get("balances", []): balance_amount = balance["balanceAmount"] - balances.append(AccountBalance( - amount=float(balance_amount["amount"]), - currency=balance_amount["currency"], - balance_type=balance["balanceType"], - last_change_date=balance.get("lastChangeDateTime") - )) - + balances.append( + AccountBalance( + amount=float(balance_amount["amount"]), + currency=balance_amount["currency"], + balance_type=balance["balanceType"], + last_change_date=balance.get("lastChangeDateTime"), + ) + ) + return APIResponse( success=True, data=balances, - message=f"Retrieved {len(balances)} balances for account {account_id}" + message=f"Retrieved {len(balances)} balances for account {account_id}", ) - + except Exception as e: logger.error(f"Failed to get balances for account {account_id}: {e}") raise HTTPException(status_code=404, detail=f"Failed to get balances: {str(e)}") @@ -139,22 +154,26 @@ async def get_account_transactions( account_id: str, limit: Optional[int] = Query(default=100, le=500), offset: Optional[int] = Query(default=0, ge=0), - summary_only: bool = Query(default=False, description="Return transaction summaries only") + summary_only: bool = Query( + default=False, description="Return transaction summaries only" + ), ) -> APIResponse: """Get transactions for a specific account""" try: account_details = await gocardless_service.get_account_details(account_id) - transactions_data = await gocardless_service.get_account_transactions(account_id) - + transactions_data = await gocardless_service.get_account_transactions( + account_id + ) + # Process transactions processed_transactions = database_service.process_transactions( account_id, account_details, transactions_data ) - + # Apply pagination total_transactions = len(processed_transactions) - paginated_transactions = processed_transactions[offset:offset + limit] - + paginated_transactions = processed_transactions[offset : offset + limit] + if summary_only: # Return simplified transaction summaries summaries = [ @@ -165,7 +184,7 @@ async def get_account_transactions( amount=txn["transactionValue"], currency=txn["transactionCurrency"], status=txn["transactionStatus"], - account_id=txn["accountId"] + account_id=txn["accountId"], ) for txn in paginated_transactions ] @@ -183,18 +202,20 @@ async def get_account_transactions( transaction_value=txn["transactionValue"], transaction_currency=txn["transactionCurrency"], transaction_status=txn["transactionStatus"], - raw_transaction=txn["rawTransaction"] + raw_transaction=txn["rawTransaction"], ) for txn in paginated_transactions ] data = transactions - + return APIResponse( success=True, data=data, - message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})" + message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})", ) - + except Exception as e: logger.error(f"Failed to get transactions for account {account_id}: {e}") - raise HTTPException(status_code=404, detail=f"Failed to get transactions: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=404, detail=f"Failed to get transactions: {str(e)}" + ) diff --git a/leggend/api/routes/banks.py b/leggend/api/routes/banks.py index db102de..27f4a38 100644 --- a/leggend/api/routes/banks.py +++ b/leggend/api/routes/banks.py @@ -4,10 +4,10 @@ from loguru import logger from leggend.api.models.common import APIResponse, ErrorResponse from leggend.api.models.banks import ( - BankInstitution, - BankConnectionRequest, + BankInstitution, + BankConnectionRequest, BankRequisition, - BankConnectionStatus + BankConnectionStatus, ) from leggend.services.gocardless_service import GoCardlessService from leggend.utils.gocardless import REQUISITION_STATUS @@ -18,12 +18,12 @@ gocardless_service = GoCardlessService() @router.get("/banks/institutions", response_model=APIResponse) async def get_bank_institutions( - country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)") + country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)"), ) -> APIResponse: """Get available bank institutions for a country""" try: institutions_data = await gocardless_service.get_institutions(country) - + institutions = [ BankInstitution( id=inst["id"], @@ -31,20 +31,22 @@ async def get_bank_institutions( bic=inst.get("bic"), transaction_total_days=inst["transaction_total_days"], countries=inst["countries"], - logo=inst.get("logo") + logo=inst.get("logo"), ) for inst in institutions_data ] - + return APIResponse( success=True, data=institutions, - message=f"Found {len(institutions)} institutions for {country}" + message=f"Found {len(institutions)} institutions for {country}", ) - + except Exception as e: logger.error(f"Failed to get institutions for {country}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get institutions: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get institutions: {str(e)}" + ) @router.post("/banks/connect", response_model=APIResponse) @@ -52,28 +54,29 @@ async def connect_to_bank(request: BankConnectionRequest) -> APIResponse: """Create a connection to a bank (requisition)""" try: requisition_data = await gocardless_service.create_requisition( - request.institution_id, - request.redirect_url + request.institution_id, request.redirect_url ) - + requisition = BankRequisition( id=requisition_data["id"], institution_id=requisition_data["institution_id"], status=requisition_data["status"], created=requisition_data["created"], link=requisition_data["link"], - accounts=requisition_data.get("accounts", []) + accounts=requisition_data.get("accounts", []), ) - + return APIResponse( success=True, data=requisition, - message=f"Bank connection created. Please visit the link to authorize." + message=f"Bank connection created. Please visit the link to authorize.", ) - + except Exception as e: logger.error(f"Failed to connect to bank {request.institution_id}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to connect to bank: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to connect to bank: {str(e)}" + ) @router.get("/banks/status", response_model=APIResponse) @@ -81,31 +84,37 @@ async def get_bank_connections_status() -> APIResponse: """Get status of all bank connections""" try: requisitions_data = await gocardless_service.get_requisitions() - + connections = [] for req in requisitions_data.get("results", []): status = req["status"] status_display = REQUISITION_STATUS.get(status, "UNKNOWN") - - connections.append(BankConnectionStatus( - bank_id=req["institution_id"], - bank_name=req["institution_id"], # Could be enhanced with actual bank names - status=status, - status_display=status_display, - created_at=req["created"], - requisition_id=req["id"], - accounts_count=len(req.get("accounts", [])) - )) - + + connections.append( + BankConnectionStatus( + bank_id=req["institution_id"], + bank_name=req[ + "institution_id" + ], # Could be enhanced with actual bank names + status=status, + status_display=status_display, + created_at=req["created"], + requisition_id=req["id"], + accounts_count=len(req.get("accounts", [])), + ) + ) + return APIResponse( success=True, data=connections, - message=f"Found {len(connections)} bank connections" + message=f"Found {len(connections)} bank connections", ) - + except Exception as e: logger.error(f"Failed to get bank connection status: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get bank status: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get bank status: {str(e)}" + ) @router.delete("/banks/connections/{requisition_id}", response_model=APIResponse) @@ -116,12 +125,14 @@ async def delete_bank_connection(requisition_id: str) -> APIResponse: # For now, return success return APIResponse( success=True, - message=f"Bank connection {requisition_id} deleted successfully" + message=f"Bank connection {requisition_id} deleted successfully", ) - + except Exception as e: logger.error(f"Failed to delete bank connection {requisition_id}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to delete connection: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to delete connection: {str(e)}" + ) @router.get("/banks/countries", response_model=APIResponse) @@ -160,9 +171,9 @@ async def get_supported_countries() -> APIResponse: {"code": "SE", "name": "Sweden"}, {"code": "GB", "name": "United Kingdom"}, ] - + return APIResponse( success=True, data=countries, - message="Supported countries retrieved successfully" - ) \ No newline at end of file + message="Supported countries retrieved successfully", + ) diff --git a/leggend/api/routes/notifications.py b/leggend/api/routes/notifications.py index 5761f7f..6bc6558 100644 --- a/leggend/api/routes/notifications.py +++ b/leggend/api/routes/notifications.py @@ -4,11 +4,11 @@ from loguru import logger from leggend.api.models.common import APIResponse from leggend.api.models.notifications import ( - NotificationSettings, - NotificationTest, + NotificationSettings, + NotificationTest, DiscordConfig, TelegramConfig, - NotificationFilters + NotificationFilters, ) from leggend.services.notification_service import NotificationService from leggend.config import config @@ -23,38 +23,44 @@ async def get_notification_settings() -> APIResponse: try: notifications_config = config.notifications_config filters_config = config.filters_config - + # Build response safely without exposing secrets discord_config = notifications_config.get("discord", {}) telegram_config = notifications_config.get("telegram", {}) - + settings = NotificationSettings( discord=DiscordConfig( webhook="***" if discord_config.get("webhook") else "", - enabled=discord_config.get("enabled", True) - ) if discord_config.get("webhook") else None, + enabled=discord_config.get("enabled", True), + ) + if discord_config.get("webhook") + else None, telegram=TelegramConfig( token="***" if telegram_config.get("token") else "", chat_id=telegram_config.get("chat_id", 0), - enabled=telegram_config.get("enabled", True) - ) if telegram_config.get("token") else None, + enabled=telegram_config.get("enabled", True), + ) + if telegram_config.get("token") + else None, filters=NotificationFilters( case_insensitive=filters_config.get("case-insensitive", {}), case_sensitive=filters_config.get("case-sensitive"), amount_threshold=filters_config.get("amount_threshold"), - keywords=filters_config.get("keywords", []) - ) + keywords=filters_config.get("keywords", []), + ), ) - + return APIResponse( success=True, data=settings, - message="Notification settings retrieved successfully" + message="Notification settings retrieved successfully", ) - + except Exception as e: logger.error(f"Failed to get notification settings: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get notification settings: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get notification settings: {str(e)}" + ) @router.put("/notifications/settings", response_model=APIResponse) @@ -63,20 +69,20 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes try: # Update notifications config notifications_config = {} - + if settings.discord: notifications_config["discord"] = { "webhook": settings.discord.webhook, - "enabled": settings.discord.enabled + "enabled": settings.discord.enabled, } - + if settings.telegram: notifications_config["telegram"] = { "token": settings.telegram.token, "chat_id": settings.telegram.chat_id, - "enabled": settings.telegram.enabled + "enabled": settings.telegram.enabled, } - + # Update filters config filters_config = {} if settings.filters.case_insensitive: @@ -87,22 +93,24 @@ async def update_notification_settings(settings: NotificationSettings) -> APIRes filters_config["amount_threshold"] = settings.filters.amount_threshold if settings.filters.keywords: filters_config["keywords"] = settings.filters.keywords - + # Save to config if notifications_config: config.update_section("notifications", notifications_config) if filters_config: config.update_section("filters", filters_config) - + return APIResponse( success=True, data={"updated": True}, - message="Notification settings updated successfully" + message="Notification settings updated successfully", ) - + except Exception as e: logger.error(f"Failed to update notification settings: {e}") - raise HTTPException(status_code=500, detail=f"Failed to update notification settings: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to update notification settings: {str(e)}" + ) @router.post("/notifications/test", response_model=APIResponse) @@ -110,25 +118,26 @@ async def test_notification(test_request: NotificationTest) -> APIResponse: """Send a test notification""" try: success = await notification_service.send_test_notification( - test_request.service, - test_request.message + test_request.service, test_request.message ) - + if success: return APIResponse( success=True, data={"sent": True}, - message=f"Test notification sent to {test_request.service} successfully" + message=f"Test notification sent to {test_request.service} successfully", ) else: return APIResponse( success=False, - message=f"Failed to send test notification to {test_request.service}" + message=f"Failed to send test notification to {test_request.service}", ) - + except Exception as e: logger.error(f"Failed to send test notification: {e}") - raise HTTPException(status_code=500, detail=f"Failed to send test notification: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to send test notification: {str(e)}" + ) @router.get("/notifications/services", response_model=APIResponse) @@ -136,37 +145,41 @@ async def get_notification_services() -> APIResponse: """Get available notification services and their status""" try: notifications_config = config.notifications_config - + services = { "discord": { "name": "Discord", "enabled": bool(notifications_config.get("discord", {}).get("webhook")), - "configured": bool(notifications_config.get("discord", {}).get("webhook")), - "active": notifications_config.get("discord", {}).get("enabled", True) + "configured": bool( + notifications_config.get("discord", {}).get("webhook") + ), + "active": notifications_config.get("discord", {}).get("enabled", True), }, "telegram": { - "name": "Telegram", + "name": "Telegram", "enabled": bool( - notifications_config.get("telegram", {}).get("token") and - notifications_config.get("telegram", {}).get("chat_id") + notifications_config.get("telegram", {}).get("token") + and notifications_config.get("telegram", {}).get("chat_id") ), "configured": bool( - notifications_config.get("telegram", {}).get("token") and - notifications_config.get("telegram", {}).get("chat_id") + notifications_config.get("telegram", {}).get("token") + and notifications_config.get("telegram", {}).get("chat_id") ), - "active": notifications_config.get("telegram", {}).get("enabled", True) - } + "active": notifications_config.get("telegram", {}).get("enabled", True), + }, } - + return APIResponse( success=True, data=services, - message="Notification services status retrieved successfully" + message="Notification services status retrieved successfully", ) - + except Exception as e: logger.error(f"Failed to get notification services: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get notification services: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get notification services: {str(e)}" + ) @router.delete("/notifications/settings/{service}", response_model=APIResponse) @@ -174,19 +187,23 @@ async def delete_notification_service(service: str) -> APIResponse: """Delete/disable a notification service""" try: if service not in ["discord", "telegram"]: - raise HTTPException(status_code=400, detail="Service must be 'discord' or 'telegram'") - + raise HTTPException( + status_code=400, detail="Service must be 'discord' or 'telegram'" + ) + notifications_config = config.notifications_config.copy() if service in notifications_config: del notifications_config[service] config.update_section("notifications", notifications_config) - + return APIResponse( success=True, data={"deleted": service}, - message=f"{service.capitalize()} notification service deleted successfully" + message=f"{service.capitalize()} notification service deleted successfully", ) - + except Exception as e: logger.error(f"Failed to delete notification service {service}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to delete notification service: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=500, detail=f"Failed to delete notification service: {str(e)}" + ) diff --git a/leggend/api/routes/sync.py b/leggend/api/routes/sync.py index 837b07b..a7f1e92 100644 --- a/leggend/api/routes/sync.py +++ b/leggend/api/routes/sync.py @@ -17,27 +17,26 @@ async def get_sync_status() -> APIResponse: """Get current sync status""" try: status = await sync_service.get_sync_status() - + # Add scheduler information next_sync_time = scheduler.get_next_sync_time() if next_sync_time: status.next_sync = next_sync_time - + return APIResponse( - success=True, - data=status, - message="Sync status retrieved successfully" + success=True, data=status, message="Sync status retrieved successfully" ) - + except Exception as e: logger.error(f"Failed to get sync status: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get sync status: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get sync status: {str(e)}" + ) @router.post("/sync", response_model=APIResponse) async def trigger_sync( - background_tasks: BackgroundTasks, - sync_request: Optional[SyncRequest] = None + background_tasks: BackgroundTasks, sync_request: Optional[SyncRequest] = None ) -> APIResponse: """Trigger a manual sync operation""" try: @@ -46,32 +45,37 @@ async def trigger_sync( if status.is_running and not (sync_request and sync_request.force): return APIResponse( success=False, - message="Sync is already running. Use 'force: true' to override." + message="Sync is already running. Use 'force: true' to override.", ) - + # Determine what to sync if sync_request and sync_request.account_ids: # Sync specific accounts in background background_tasks.add_task( - sync_service.sync_specific_accounts, - sync_request.account_ids, - sync_request.force if sync_request else False + sync_service.sync_specific_accounts, + sync_request.account_ids, + sync_request.force if sync_request else False, + ) + message = ( + f"Started sync for {len(sync_request.account_ids)} specific accounts" ) - message = f"Started sync for {len(sync_request.account_ids)} specific accounts" else: # Sync all accounts in background background_tasks.add_task( - sync_service.sync_all_accounts, - sync_request.force if sync_request else False + sync_service.sync_all_accounts, + sync_request.force if sync_request else False, ) message = "Started sync for all accounts" - + return APIResponse( success=True, - data={"sync_started": True, "force": sync_request.force if sync_request else False}, - message=message + data={ + "sync_started": True, + "force": sync_request.force if sync_request else False, + }, + message=message, ) - + except Exception as e: logger.error(f"Failed to trigger sync: {e}") raise HTTPException(status_code=500, detail=f"Failed to trigger sync: {str(e)}") @@ -83,20 +87,21 @@ async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse: try: if sync_request and sync_request.account_ids: result = await sync_service.sync_specific_accounts( - sync_request.account_ids, - sync_request.force + sync_request.account_ids, sync_request.force ) else: result = await sync_service.sync_all_accounts( sync_request.force if sync_request else False ) - + return APIResponse( success=result.success, data=result, - message="Sync completed" if result.success else f"Sync failed with {len(result.errors)} errors" + message="Sync completed" + if result.success + else f"Sync failed with {len(result.errors)} errors", ) - + except Exception as e: logger.error(f"Failed to run sync: {e}") raise HTTPException(status_code=500, detail=f"Failed to run sync: {str(e)}") @@ -108,22 +113,28 @@ async def get_scheduler_config() -> APIResponse: try: scheduler_config = config.scheduler_config next_sync_time = scheduler.get_next_sync_time() - + response_data = { **scheduler_config, - "next_scheduled_sync": next_sync_time.isoformat() if next_sync_time else None, - "is_running": scheduler.scheduler.running if hasattr(scheduler, 'scheduler') else False + "next_scheduled_sync": next_sync_time.isoformat() + if next_sync_time + else None, + "is_running": scheduler.scheduler.running + if hasattr(scheduler, "scheduler") + else False, } - + return APIResponse( success=True, data=response_data, - message="Scheduler configuration retrieved successfully" + message="Scheduler configuration retrieved successfully", ) - + except Exception as e: logger.error(f"Failed to get scheduler config: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get scheduler config: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get scheduler config: {str(e)}" + ) @router.put("/sync/scheduler", response_model=APIResponse) @@ -135,26 +146,32 @@ async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIRespo try: cron_parts = scheduler_config.cron.split() if len(cron_parts) != 5: - raise ValueError("Cron expression must have 5 parts: minute hour day month day_of_week") + raise ValueError( + "Cron expression must have 5 parts: minute hour day month day_of_week" + ) except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid cron expression: {str(e)}") - + raise HTTPException( + status_code=400, detail=f"Invalid cron expression: {str(e)}" + ) + # Update configuration schedule_data = scheduler_config.dict(exclude_none=True) config.update_section("scheduler", {"sync": schedule_data}) - + # Reschedule the job scheduler.reschedule_sync(schedule_data) - + return APIResponse( success=True, data=schedule_data, - message="Scheduler configuration updated successfully" + message="Scheduler configuration updated successfully", ) - + except Exception as e: logger.error(f"Failed to update scheduler config: {e}") - raise HTTPException(status_code=500, detail=f"Failed to update scheduler config: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to update scheduler config: {str(e)}" + ) @router.post("/sync/scheduler/start", response_model=APIResponse) @@ -163,37 +180,29 @@ async def start_scheduler() -> APIResponse: try: if not scheduler.scheduler.running: scheduler.start() - return APIResponse( - success=True, - message="Scheduler started successfully" - ) + return APIResponse(success=True, message="Scheduler started successfully") else: - return APIResponse( - success=True, - message="Scheduler is already running" - ) - + return APIResponse(success=True, message="Scheduler is already running") + except Exception as e: logger.error(f"Failed to start scheduler: {e}") - raise HTTPException(status_code=500, detail=f"Failed to start scheduler: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to start scheduler: {str(e)}" + ) -@router.post("/sync/scheduler/stop", response_model=APIResponse) +@router.post("/sync/scheduler/stop", response_model=APIResponse) async def stop_scheduler() -> APIResponse: """Stop the background scheduler""" try: if scheduler.scheduler.running: scheduler.shutdown() - return APIResponse( - success=True, - message="Scheduler stopped successfully" - ) + return APIResponse(success=True, message="Scheduler stopped successfully") else: - return APIResponse( - success=True, - message="Scheduler is already stopped" - ) - + return APIResponse(success=True, message="Scheduler is already stopped") + except Exception as e: logger.error(f"Failed to stop scheduler: {e}") - raise HTTPException(status_code=500, detail=f"Failed to stop scheduler: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=500, detail=f"Failed to stop scheduler: {str(e)}" + ) diff --git a/leggend/api/routes/transactions.py b/leggend/api/routes/transactions.py index ff8ada0..633e9a0 100644 --- a/leggend/api/routes/transactions.py +++ b/leggend/api/routes/transactions.py @@ -17,95 +17,111 @@ database_service = DatabaseService() async def get_all_transactions( limit: Optional[int] = Query(default=100, le=500), offset: Optional[int] = Query(default=0, ge=0), - summary_only: bool = Query(default=True, description="Return transaction summaries only"), - date_from: Optional[str] = Query(default=None, description="Filter from date (YYYY-MM-DD)"), - date_to: Optional[str] = Query(default=None, description="Filter to date (YYYY-MM-DD)"), - min_amount: Optional[float] = Query(default=None, description="Minimum transaction amount"), - max_amount: Optional[float] = Query(default=None, description="Maximum transaction amount"), - search: Optional[str] = Query(default=None, description="Search in transaction descriptions"), - account_id: Optional[str] = Query(default=None, description="Filter by account ID") + summary_only: bool = Query( + default=True, description="Return transaction summaries only" + ), + date_from: Optional[str] = Query( + default=None, description="Filter from date (YYYY-MM-DD)" + ), + date_to: Optional[str] = Query( + default=None, description="Filter to date (YYYY-MM-DD)" + ), + min_amount: Optional[float] = Query( + default=None, description="Minimum transaction amount" + ), + max_amount: Optional[float] = Query( + default=None, description="Maximum transaction amount" + ), + search: Optional[str] = Query( + default=None, description="Search in transaction descriptions" + ), + account_id: Optional[str] = Query(default=None, description="Filter by account ID"), ) -> APIResponse: """Get all transactions across all accounts with filtering options""" try: # Get all requisitions and accounts requisitions_data = await gocardless_service.get_requisitions() all_accounts = set() - + for req in requisitions_data.get("results", []): all_accounts.update(req.get("accounts", [])) - + # Filter by specific account if requested if account_id: if account_id not in all_accounts: raise HTTPException(status_code=404, detail="Account not found") all_accounts = {account_id} - + all_transactions = [] - + # Collect transactions from all accounts for acc_id in all_accounts: try: account_details = await gocardless_service.get_account_details(acc_id) - transactions_data = await gocardless_service.get_account_transactions(acc_id) - + transactions_data = await gocardless_service.get_account_transactions( + acc_id + ) + processed_transactions = database_service.process_transactions( acc_id, account_details, transactions_data ) all_transactions.extend(processed_transactions) - + except Exception as e: logger.error(f"Failed to get transactions for account {acc_id}: {e}") continue - + # Apply filters filtered_transactions = all_transactions - + # Date range filter if date_from: from_date = datetime.fromisoformat(date_from) filtered_transactions = [ - txn for txn in filtered_transactions + txn + for txn in filtered_transactions if txn["transactionDate"] >= from_date ] - + if date_to: to_date = datetime.fromisoformat(date_to) filtered_transactions = [ - txn for txn in filtered_transactions + txn + for txn in filtered_transactions if txn["transactionDate"] <= to_date ] - + # Amount filters if min_amount is not None: filtered_transactions = [ - txn for txn in filtered_transactions + txn + for txn in filtered_transactions if txn["transactionValue"] >= min_amount ] - + if max_amount is not None: filtered_transactions = [ - txn for txn in filtered_transactions + txn + for txn in filtered_transactions if txn["transactionValue"] <= max_amount ] - + # Search filter if search: search_lower = search.lower() filtered_transactions = [ - txn for txn in filtered_transactions + txn + for txn in filtered_transactions if search_lower in txn["description"].lower() ] - + # Sort by date (newest first) - filtered_transactions.sort( - key=lambda x: x["transactionDate"], - reverse=True - ) - + filtered_transactions.sort(key=lambda x: x["transactionDate"], reverse=True) + # Apply pagination total_transactions = len(filtered_transactions) - paginated_transactions = filtered_transactions[offset:offset + limit] - + paginated_transactions = filtered_transactions[offset : offset + limit] + if summary_only: # Return simplified transaction summaries data = [ @@ -116,7 +132,7 @@ async def get_all_transactions( amount=txn["transactionValue"], currency=txn["transactionCurrency"], status=txn["transactionStatus"], - account_id=txn["accountId"] + account_id=txn["accountId"], ) for txn in paginated_transactions ] @@ -133,86 +149,99 @@ async def get_all_transactions( transaction_value=txn["transactionValue"], transaction_currency=txn["transactionCurrency"], transaction_status=txn["transactionStatus"], - raw_transaction=txn["rawTransaction"] + raw_transaction=txn["rawTransaction"], ) for txn in paginated_transactions ] - + return APIResponse( success=True, data=data, - message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})" + message=f"Retrieved {len(data)} transactions (showing {offset + 1}-{offset + len(data)} of {total_transactions})", ) - + except Exception as e: logger.error(f"Failed to get transactions: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get transactions: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get transactions: {str(e)}" + ) @router.get("/transactions/stats", response_model=APIResponse) async def get_transaction_stats( days: int = Query(default=30, description="Number of days to include in stats"), - account_id: Optional[str] = Query(default=None, description="Filter by account ID") + account_id: Optional[str] = Query(default=None, description="Filter by account ID"), ) -> APIResponse: """Get transaction statistics for the last N days""" try: # Date range for stats end_date = datetime.now() start_date = end_date - timedelta(days=days) - + # Get all transactions (reuse the existing endpoint logic) # This is a simplified implementation - in practice you might want to optimize this requisitions_data = await gocardless_service.get_requisitions() all_accounts = set() - + for req in requisitions_data.get("results", []): all_accounts.update(req.get("accounts", [])) - + if account_id: if account_id not in all_accounts: raise HTTPException(status_code=404, detail="Account not found") all_accounts = {account_id} - + all_transactions = [] - + for acc_id in all_accounts: try: account_details = await gocardless_service.get_account_details(acc_id) - transactions_data = await gocardless_service.get_account_transactions(acc_id) - + transactions_data = await gocardless_service.get_account_transactions( + acc_id + ) + processed_transactions = database_service.process_transactions( acc_id, account_details, transactions_data ) all_transactions.extend(processed_transactions) - + except Exception as e: logger.error(f"Failed to get transactions for account {acc_id}: {e}") continue - + # Filter transactions by date range recent_transactions = [ - txn for txn in all_transactions + txn + for txn in all_transactions if start_date <= txn["transactionDate"] <= end_date ] - + # Calculate stats total_transactions = len(recent_transactions) total_income = sum( - txn["transactionValue"] - for txn in recent_transactions + txn["transactionValue"] + for txn in recent_transactions if txn["transactionValue"] > 0 ) total_expenses = sum( - abs(txn["transactionValue"]) - for txn in recent_transactions + abs(txn["transactionValue"]) + for txn in recent_transactions if txn["transactionValue"] < 0 ) net_change = total_income - total_expenses - + # Count by status - booked_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]) - pending_count = len([txn for txn in recent_transactions if txn["transactionStatus"] == "pending"]) - + booked_count = len( + [txn for txn in recent_transactions if txn["transactionStatus"] == "booked"] + ) + pending_count = len( + [ + txn + for txn in recent_transactions + if txn["transactionStatus"] == "pending" + ] + ) + stats = { "period_days": days, "total_transactions": total_transactions, @@ -222,17 +251,23 @@ async def get_transaction_stats( "total_expenses": round(total_expenses, 2), "net_change": round(net_change, 2), "average_transaction": round( - sum(txn["transactionValue"] for txn in recent_transactions) / total_transactions, 2 - ) if total_transactions > 0 else 0, - "accounts_included": len(all_accounts) + sum(txn["transactionValue"] for txn in recent_transactions) + / total_transactions, + 2, + ) + if total_transactions > 0 + else 0, + "accounts_included": len(all_accounts), } - + return APIResponse( success=True, data=stats, - message=f"Transaction statistics for last {days} days" + message=f"Transaction statistics for last {days} days", ) - + except Exception as e: logger.error(f"Failed to get transaction stats: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get transaction stats: {str(e)}") \ No newline at end of file + raise HTTPException( + status_code=500, detail=f"Failed to get transaction stats: {str(e)}" + ) diff --git a/leggend/background/__init__.py b/leggend/background/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/background/scheduler.py b/leggend/background/scheduler.py index 6c662ed..e2a9648 100644 --- a/leggend/background/scheduler.py +++ b/leggend/background/scheduler.py @@ -4,47 +4,30 @@ from loguru import logger from leggend.config import config from leggend.services.sync_service import SyncService +from leggend.services.notification_service import NotificationService class BackgroundScheduler: def __init__(self): self.scheduler = AsyncIOScheduler() self.sync_service = SyncService() + self.notification_service = NotificationService() + self.max_retries = 3 + self.retry_delay = 300 # 5 minutes def start(self): """Start the scheduler and configure sync jobs based on configuration""" schedule_config = config.scheduler_config.get("sync", {}) - + if not schedule_config.get("enabled", True): logger.info("Sync scheduling is disabled in configuration") self.scheduler.start() return - # Use custom cron expression if provided, otherwise use hour/minute - if schedule_config.get("cron"): - # Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM) - try: - cron_parts = schedule_config["cron"].split() - if len(cron_parts) == 5: - minute, hour, day, month, day_of_week = cron_parts - trigger = CronTrigger( - minute=minute, - hour=hour, - day=day if day != "*" else None, - month=month if month != "*" else None, - day_of_week=day_of_week if day_of_week != "*" else None, - ) - else: - logger.error(f"Invalid cron expression: {schedule_config['cron']}") - return - except Exception as e: - logger.error(f"Error parsing cron expression: {e}") - return - else: - # Use hour/minute configuration (default: 3:00 AM daily) - hour = schedule_config.get("hour", 3) - minute = schedule_config.get("minute", 0) - trigger = CronTrigger(hour=hour, minute=minute) + # Parse schedule configuration + trigger = self._parse_cron_config(schedule_config) + if not trigger: + return self.scheduler.add_job( self._run_sync, @@ -53,7 +36,7 @@ class BackgroundScheduler: name="Scheduled sync of all transactions", max_instances=1, ) - + self.scheduler.start() logger.info(f"Background scheduler started with sync job: {trigger}") @@ -76,28 +59,9 @@ class BackgroundScheduler: return # Configure new schedule - if schedule_config.get("cron"): - try: - cron_parts = schedule_config["cron"].split() - if len(cron_parts) == 5: - minute, hour, day, month, day_of_week = cron_parts - trigger = CronTrigger( - minute=minute, - hour=hour, - day=day if day != "*" else None, - month=month if month != "*" else None, - day_of_week=day_of_week if day_of_week != "*" else None, - ) - else: - logger.error(f"Invalid cron expression: {schedule_config['cron']}") - return - except Exception as e: - logger.error(f"Error parsing cron expression: {e}") - return - else: - hour = schedule_config.get("hour", 3) - minute = schedule_config.get("minute", 0) - trigger = CronTrigger(hour=hour, minute=minute) + trigger = self._parse_cron_config(schedule_config) + if not trigger: + return self.scheduler.add_job( self._run_sync, @@ -108,13 +72,90 @@ class BackgroundScheduler: ) logger.info(f"Rescheduled sync job with: {trigger}") - async def _run_sync(self): + def _parse_cron_config(self, schedule_config: dict) -> CronTrigger: + """Parse cron configuration and return CronTrigger""" + if schedule_config.get("cron"): + # Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM) + try: + cron_parts = schedule_config["cron"].split() + if len(cron_parts) == 5: + minute, hour, day, month, day_of_week = cron_parts + return CronTrigger( + minute=minute, + hour=hour, + day=day if day != "*" else None, + month=month if month != "*" else None, + day_of_week=day_of_week if day_of_week != "*" else None, + ) + else: + logger.error(f"Invalid cron expression: {schedule_config['cron']}") + return None + except Exception as e: + logger.error(f"Error parsing cron expression: {e}") + return None + else: + # Use hour/minute configuration (default: 3:00 AM daily) + hour = schedule_config.get("hour", 3) + minute = schedule_config.get("minute", 0) + return CronTrigger(hour=hour, minute=minute) + + async def _run_sync(self, retry_count: int = 0): + """Run sync with enhanced error handling and retry logic""" try: logger.info("Starting scheduled sync job") await self.sync_service.sync_all_accounts() logger.info("Scheduled sync job completed successfully") except Exception as e: - logger.error(f"Scheduled sync job failed: {e}") + logger.error( + f"Scheduled sync job failed (attempt {retry_count + 1}/{self.max_retries}): {e}" + ) + + # Send notification about the failure + try: + await self.notification_service.send_expiry_notification( + { + "type": "sync_failure", + "error": str(e), + "retry_count": retry_count + 1, + "max_retries": self.max_retries, + } + ) + except Exception as notification_error: + logger.error( + f"Failed to send failure notification: {notification_error}" + ) + + # Implement retry logic for transient failures + if retry_count < self.max_retries - 1: + import datetime + + logger.info(f"Retrying sync job in {self.retry_delay} seconds...") + # Schedule a retry + retry_time = datetime.datetime.now() + datetime.timedelta( + seconds=self.retry_delay + ) + self.scheduler.add_job( + self._run_sync, + "date", + args=[retry_count + 1], + id=f"sync_retry_{retry_count + 1}", + run_date=retry_time, + ) + else: + logger.error("Maximum retries exceeded for sync job") + # Send final failure notification + try: + await self.notification_service.send_expiry_notification( + { + "type": "sync_final_failure", + "error": str(e), + "retry_count": retry_count + 1, + } + ) + except Exception as notification_error: + logger.error( + f"Failed to send final failure notification: {notification_error}" + ) def get_next_sync_time(self): """Get the next scheduled sync time""" @@ -124,4 +165,4 @@ class BackgroundScheduler: return None -scheduler = BackgroundScheduler() \ No newline at end of file +scheduler = BackgroundScheduler() diff --git a/leggend/config.py b/leggend/config.py index 71af828..b09e180 100644 --- a/leggend/config.py +++ b/leggend/config.py @@ -24,7 +24,7 @@ class Config: if config_path is None: config_path = os.environ.get( "LEGGEN_CONFIG_FILE", - str(Path.home() / ".config" / "leggen" / "config.toml") + str(Path.home() / ".config" / "leggen" / "config.toml"), ) self._config_path = config_path @@ -42,15 +42,17 @@ class Config: return self._config - def save_config(self, config_data: Dict[str, Any] = None, config_path: str = None) -> None: + def save_config( + self, config_data: Dict[str, Any] = None, config_path: str = None + ) -> None: """Save configuration to TOML file""" if config_data is None: config_data = self._config - + if config_path is None: config_path = self._config_path or os.environ.get( "LEGGEN_CONFIG_FILE", - str(Path.home() / ".config" / "leggen" / "config.toml") + str(Path.home() / ".config" / "leggen" / "config.toml"), ) # Ensure directory exists @@ -59,7 +61,7 @@ class Config: try: with open(config_path, "wb") as f: tomli_w.dump(config_data, f) - + # Update in-memory config self._config = config_data self._config_path = config_path @@ -72,10 +74,10 @@ class Config: """Update a specific configuration value""" if self._config is None: self.load_config() - + if section not in self._config: self._config[section] = {} - + self._config[section][key] = value self.save_config() @@ -83,7 +85,7 @@ class Config: """Update an entire configuration section""" if self._config is None: self.load_config() - + self._config[section] = data self.save_config() @@ -117,10 +119,10 @@ class Config: "enabled": True, "hour": 3, "minute": 0, - "cron": None # Optional custom cron expression + "cron": None, # Optional custom cron expression } } return self.config.get("scheduler", default_schedule) -config = Config() \ No newline at end of file +config = Config() diff --git a/leggend/main.py b/leggend/main.py index 7ca0615..1693803 100644 --- a/leggend/main.py +++ b/leggend/main.py @@ -1,5 +1,6 @@ import asyncio from contextlib import asynccontextmanager +from importlib import metadata import uvicorn from fastapi import FastAPI @@ -15,7 +16,7 @@ from leggend.config import config async def lifespan(app: FastAPI): # Startup logger.info("Starting leggend service...") - + # Load configuration try: config.load_config() @@ -27,26 +28,35 @@ async def lifespan(app: FastAPI): # Start background scheduler scheduler.start() logger.info("Background scheduler started") - + yield - + # Shutdown logger.info("Shutting down leggend service...") scheduler.shutdown() def create_app() -> FastAPI: + # Get version dynamically from package metadata + try: + version = metadata.version("leggen") + except metadata.PackageNotFoundError: + version = "unknown" + app = FastAPI( title="Leggend API", description="Open Banking API for Leggen", - version="0.6.11", + version=version, lifespan=lifespan, ) # Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:3000", "http://localhost:5173"], # SvelteKit dev servers + allow_origins=[ + "http://localhost:3000", + "http://localhost:5173", + ], # SvelteKit dev servers allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -60,7 +70,12 @@ def create_app() -> FastAPI: @app.get("/") async def root(): - return {"message": "Leggend API is running", "version": "0.6.11"} + # Get version dynamically + try: + version = metadata.version("leggen") + except metadata.PackageNotFoundError: + version = "unknown" + return {"message": "Leggend API is running", "version": version} @app.get("/health") async def health(): @@ -71,25 +86,19 @@ def create_app() -> FastAPI: def main(): import argparse + parser = argparse.ArgumentParser(description="Start the Leggend API service") parser.add_argument( - "--reload", - action="store_true", - help="Enable auto-reload for development" + "--reload", action="store_true", help="Enable auto-reload for development" ) parser.add_argument( - "--host", - default="0.0.0.0", - help="Host to bind to (default: 0.0.0.0)" + "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)" ) parser.add_argument( - "--port", - type=int, - default=8000, - help="Port to bind to (default: 8000)" + "--port", type=int, default=8000, help="Port to bind to (default: 8000)" ) args = parser.parse_args() - + if args.reload: # Use string import for reload to work properly uvicorn.run( @@ -114,4 +123,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/leggend/services/__init__.py b/leggend/services/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/services/database_service.py b/leggend/services/database_service.py index d7ed04d..87b7911 100644 --- a/leggend/services/database_service.py +++ b/leggend/services/database_service.py @@ -11,7 +11,9 @@ class DatabaseService: self.db_config = config.database_config self.sqlite_enabled = self.db_config.get("sqlite", True) - async def persist_balance(self, account_id: str, balance_data: Dict[str, Any]) -> None: + async def persist_balance( + self, account_id: str, balance_data: Dict[str, Any] + ) -> None: """Persist account balance data""" if not self.sqlite_enabled: logger.warning("SQLite database disabled, skipping balance persistence") @@ -19,7 +21,9 @@ class DatabaseService: await self._persist_balance_sqlite(account_id, balance_data) - async def persist_transactions(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + async def persist_transactions( + self, account_id: str, transactions: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """Persist transactions and return new transactions""" if not self.sqlite_enabled: logger.warning("SQLite database disabled, skipping transaction persistence") @@ -27,32 +31,48 @@ class DatabaseService: return await self._persist_transactions_sqlite(account_id, transactions) - def process_transactions(self, account_id: str, account_info: Dict[str, Any], transaction_data: Dict[str, Any]) -> List[Dict[str, Any]]: + def process_transactions( + self, + account_id: str, + account_info: Dict[str, Any], + transaction_data: Dict[str, Any], + ) -> List[Dict[str, Any]]: """Process raw transaction data into standardized format""" transactions = [] - + # Process booked transactions for transaction in transaction_data.get("transactions", {}).get("booked", []): - processed = self._process_single_transaction(account_id, account_info, transaction, "booked") + processed = self._process_single_transaction( + account_id, account_info, transaction, "booked" + ) transactions.append(processed) - # Process pending transactions + # Process pending transactions for transaction in transaction_data.get("transactions", {}).get("pending", []): - processed = self._process_single_transaction(account_id, account_info, transaction, "pending") + processed = self._process_single_transaction( + account_id, account_info, transaction, "pending" + ) transactions.append(processed) return transactions - def _process_single_transaction(self, account_id: str, account_info: Dict[str, Any], transaction: Dict[str, Any], status: str) -> Dict[str, Any]: + def _process_single_transaction( + self, + account_id: str, + account_info: Dict[str, Any], + transaction: Dict[str, Any], + status: str, + ) -> Dict[str, Any]: """Process a single transaction into standardized format""" # Extract dates - booked_date = transaction.get("bookingDateTime") or transaction.get("bookingDate") + booked_date = transaction.get("bookingDateTime") or transaction.get( + "bookingDate" + ) value_date = transaction.get("valueDateTime") or transaction.get("valueDate") - + if booked_date and value_date: min_date = min( - datetime.fromisoformat(booked_date), - datetime.fromisoformat(value_date) + datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date) ) else: min_date = datetime.fromisoformat(booked_date or value_date) @@ -65,7 +85,7 @@ class DatabaseService: # Extract description description = transaction.get( "remittanceInformationUnstructured", - ",".join(transaction.get("remittanceInformationUnstructuredArray", [])) + ",".join(transaction.get("remittanceInformationUnstructuredArray", [])), ) return { @@ -81,13 +101,19 @@ class DatabaseService: "rawTransaction": transaction, } - async def _persist_balance_sqlite(self, account_id: str, balance_data: Dict[str, Any]) -> None: + async def _persist_balance_sqlite( + self, account_id: str, balance_data: Dict[str, Any] + ) -> None: """Persist balance to SQLite - placeholder implementation""" # Would import and use leggen.database.sqlite logger.info(f"Persisting balance to SQLite for account {account_id}") - async def _persist_transactions_sqlite(self, account_id: str, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + async def _persist_transactions_sqlite( + self, account_id: str, transactions: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """Persist transactions to SQLite - placeholder implementation""" # Would import and use leggen.database.sqlite - logger.info(f"Persisting {len(transactions)} transactions to SQLite for account {account_id}") - return transactions # Return new transactions for notifications \ No newline at end of file + logger.info( + f"Persisting {len(transactions)} transactions to SQLite for account {account_id}" + ) + return transactions # Return new transactions for notifications diff --git a/leggend/services/gocardless_service.py b/leggend/services/gocardless_service.py index 4a44c85..073c83a 100644 --- a/leggend/services/gocardless_service.py +++ b/leggend/services/gocardless_service.py @@ -12,37 +12,36 @@ from leggend.config import config class GoCardlessService: def __init__(self): self.config = config.gocardless_config - self.base_url = self.config.get("url", "https://bankaccountdata.gocardless.com/api/v2") + self.base_url = self.config.get( + "url", "https://bankaccountdata.gocardless.com/api/v2" + ) self._token = None async def _get_auth_headers(self) -> Dict[str, str]: """Get authentication headers for GoCardless API""" token = await self._get_token() - return { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json" - } + return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} async def _get_token(self) -> str: """Get access token for GoCardless API""" if self._token: return self._token - + # Use ~/.config/leggen for consistency with main config auth_file = Path.home() / ".config" / "leggen" / "auth.json" - + if auth_file.exists(): try: with open(auth_file, "r") as f: auth = json.load(f) - + if auth.get("access"): # Try to refresh the token async with httpx.AsyncClient() as client: try: response = await client.post( f"{self.base_url}/token/refresh/", - json={"refresh": auth["refresh"]} + json={"refresh": auth["refresh"]}, ) response.raise_for_status() auth.update(response.json()) @@ -84,7 +83,7 @@ class GoCardlessService: """Save authentication data to file""" auth_file = Path.home() / ".config" / "leggen" / "auth.json" auth_file.parent.mkdir(parents=True, exist_ok=True) - + with open(auth_file, "w") as f: json.dump(auth_data, f) @@ -95,22 +94,21 @@ class GoCardlessService: response = await client.get( f"{self.base_url}/institutions/", headers=headers, - params={"country": country} + params={"country": country}, ) response.raise_for_status() return response.json() - async def create_requisition(self, institution_id: str, redirect_url: str) -> Dict[str, Any]: + async def create_requisition( + self, institution_id: str, redirect_url: str + ) -> Dict[str, Any]: """Create a bank connection requisition""" headers = await self._get_auth_headers() async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/requisitions/", headers=headers, - json={ - "institution_id": institution_id, - "redirect": redirect_url - } + json={"institution_id": institution_id, "redirect": redirect_url}, ) response.raise_for_status() return response.json() @@ -120,8 +118,7 @@ class GoCardlessService: headers = await self._get_auth_headers() async with httpx.AsyncClient() as client: response = await client.get( - f"{self.base_url}/requisitions/", - headers=headers + f"{self.base_url}/requisitions/", headers=headers ) response.raise_for_status() return response.json() @@ -131,8 +128,7 @@ class GoCardlessService: headers = await self._get_auth_headers() async with httpx.AsyncClient() as client: response = await client.get( - f"{self.base_url}/accounts/{account_id}/", - headers=headers + f"{self.base_url}/accounts/{account_id}/", headers=headers ) response.raise_for_status() return response.json() @@ -142,8 +138,7 @@ class GoCardlessService: headers = await self._get_auth_headers() async with httpx.AsyncClient() as client: response = await client.get( - f"{self.base_url}/accounts/{account_id}/balances/", - headers=headers + f"{self.base_url}/accounts/{account_id}/balances/", headers=headers ) response.raise_for_status() return response.json() @@ -153,8 +148,7 @@ class GoCardlessService: headers = await self._get_auth_headers() async with httpx.AsyncClient() as client: response = await client.get( - f"{self.base_url}/accounts/{account_id}/transactions/", - headers=headers + f"{self.base_url}/accounts/{account_id}/transactions/", headers=headers ) response.raise_for_status() - return response.json() \ No newline at end of file + return response.json() diff --git a/leggend/services/notification_service.py b/leggend/services/notification_service.py index 3cfe055..8d29fd7 100644 --- a/leggend/services/notification_service.py +++ b/leggend/services/notification_service.py @@ -10,7 +10,9 @@ class NotificationService: self.notifications_config = config.notifications_config self.filters_config = config.filters_config - async def send_transaction_notifications(self, transactions: List[Dict[str, Any]]) -> None: + async def send_transaction_notifications( + self, transactions: List[Dict[str, Any]] + ) -> None: """Send notifications for new transactions that match filters""" if not self.filters_config: logger.info("No notification filters configured, skipping notifications") @@ -18,7 +20,7 @@ class NotificationService: # Filter transactions that match notification criteria matching_transactions = self._filter_transactions(transactions) - + if not matching_transactions: logger.info("No transactions matched notification filters") return @@ -26,7 +28,7 @@ class NotificationService: # Send to enabled notification services if self._is_discord_enabled(): await self._send_discord_notifications(matching_transactions) - + if self._is_telegram_enabled(): await self._send_telegram_notifications(matching_transactions) @@ -40,7 +42,9 @@ class NotificationService: await self._send_telegram_test(message) return True else: - logger.error(f"Notification service '{service}' not enabled or not found") + logger.error( + f"Notification service '{service}' not enabled or not found" + ) return False except Exception as e: logger.error(f"Failed to send test notification to {service}: {e}") @@ -50,54 +54,66 @@ class NotificationService: """Send notification about account expiry""" if self._is_discord_enabled(): await self._send_discord_expiry(notification_data) - + if self._is_telegram_enabled(): await self._send_telegram_expiry(notification_data) - def _filter_transactions(self, transactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _filter_transactions( + self, transactions: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """Filter transactions based on notification criteria""" matching = [] filters_case_insensitive = self.filters_config.get("case-insensitive", {}) - + for transaction in transactions: description = transaction.get("description", "").lower() - + # Check case-insensitive filters for filter_name, filter_value in filters_case_insensitive.items(): if filter_value.lower() in description: - matching.append({ - "name": transaction["description"], - "value": transaction["transactionValue"], - "currency": transaction["transactionCurrency"], - "date": transaction["transactionDate"], - }) + matching.append( + { + "name": transaction["description"], + "value": transaction["transactionValue"], + "currency": transaction["transactionCurrency"], + "date": transaction["transactionDate"], + } + ) break - + return matching def _is_discord_enabled(self) -> bool: """Check if Discord notifications are enabled""" discord_config = self.notifications_config.get("discord", {}) - return bool(discord_config.get("webhook") and discord_config.get("enabled", True)) + return bool( + discord_config.get("webhook") and discord_config.get("enabled", True) + ) def _is_telegram_enabled(self) -> bool: """Check if Telegram notifications are enabled""" telegram_config = self.notifications_config.get("telegram", {}) return bool( - telegram_config.get("token") and - telegram_config.get("chat_id") and - telegram_config.get("enabled", True) + telegram_config.get("token") + and telegram_config.get("chat_id") + and telegram_config.get("enabled", True) ) - async def _send_discord_notifications(self, transactions: List[Dict[str, Any]]) -> None: + async def _send_discord_notifications( + self, transactions: List[Dict[str, Any]] + ) -> None: """Send Discord notifications - placeholder implementation""" # Would import and use leggen.notifications.discord logger.info(f"Sending {len(transactions)} transaction notifications to Discord") - async def _send_telegram_notifications(self, transactions: List[Dict[str, Any]]) -> None: + async def _send_telegram_notifications( + self, transactions: List[Dict[str, Any]] + ) -> None: """Send Telegram notifications - placeholder implementation""" # Would import and use leggen.notifications.telegram - logger.info(f"Sending {len(transactions)} transaction notifications to Telegram") + logger.info( + f"Sending {len(transactions)} transaction notifications to Telegram" + ) async def _send_discord_test(self, message: str) -> None: """Send Discord test notification""" @@ -113,4 +129,4 @@ class NotificationService: async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None: """Send Telegram expiry notification""" - logger.info(f"Sending Telegram expiry notification: {notification_data}") \ No newline at end of file + logger.info(f"Sending Telegram expiry notification: {notification_data}") diff --git a/leggend/services/sync_service.py b/leggend/services/sync_service.py index 6233c3d..ba9d14a 100644 --- a/leggend/services/sync_service.py +++ b/leggend/services/sync_service.py @@ -30,7 +30,7 @@ class SyncService: start_time = datetime.now() self._sync_status.is_running = True self._sync_status.errors = [] - + accounts_processed = 0 transactions_added = 0 transactions_updated = 0 @@ -39,22 +39,24 @@ class SyncService: try: logger.info("Starting sync of all accounts") - + # Get all requisitions and accounts requisitions = await self.gocardless.get_requisitions() all_accounts = set() - + for req in requisitions.get("results", []): all_accounts.update(req.get("accounts", [])) self._sync_status.total_accounts = len(all_accounts) - + # Process each account for account_id in all_accounts: try: # Get account details - account_details = await self.gocardless.get_account_details(account_id) - + account_details = await self.gocardless.get_account_details( + account_id + ) + # Get and save balances balances = await self.gocardless.get_account_balances(account_id) if balances: @@ -62,7 +64,9 @@ class SyncService: balances_updated += len(balances.get("balances", [])) # Get and save transactions - transactions = await self.gocardless.get_account_transactions(account_id) + transactions = await self.gocardless.get_account_transactions( + account_id + ) if transactions: processed_transactions = self.database.process_transactions( account_id, account_details, transactions @@ -71,16 +75,18 @@ class SyncService: account_id, processed_transactions ) transactions_added += len(new_transactions) - + # Send notifications for new transactions if new_transactions: - await self.notifications.send_transaction_notifications(new_transactions) + await self.notifications.send_transaction_notifications( + new_transactions + ) accounts_processed += 1 self._sync_status.accounts_synced = accounts_processed - + logger.info(f"Synced account {account_id} successfully") - + except Exception as e: error_msg = f"Failed to sync account {account_id}: {str(e)}" errors.append(error_msg) @@ -88,9 +94,9 @@ class SyncService: end_time = datetime.now() duration = (end_time - start_time).total_seconds() - + self._sync_status.last_sync = end_time - + result = SyncResult( success=len(errors) == 0, accounts_processed=accounts_processed, @@ -100,12 +106,14 @@ class SyncService: duration_seconds=duration, errors=errors, started_at=start_time, - completed_at=end_time + completed_at=end_time, + ) + + logger.info( + f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions" ) - - logger.info(f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions") return result - + except Exception as e: error_msg = f"Sync failed: {str(e)}" errors.append(error_msg) @@ -114,7 +122,9 @@ class SyncService: finally: self._sync_status.is_running = False - async def sync_specific_accounts(self, account_ids: List[str], force: bool = False) -> SyncResult: + async def sync_specific_accounts( + self, account_ids: List[str], force: bool = False + ) -> SyncResult: """Sync specific accounts""" if self._sync_status.is_running and not force: raise Exception("Sync is already running") @@ -123,12 +133,12 @@ class SyncService: # For brevity, implementing a simplified version start_time = datetime.now() self._sync_status.is_running = True - + try: # Process only specified accounts # Implementation would be similar to sync_all_accounts # but filtered to only the specified account_ids - + end_time = datetime.now() return SyncResult( success=True, @@ -139,7 +149,7 @@ class SyncService: duration_seconds=(end_time - start_time).total_seconds(), errors=[], started_at=start_time, - completed_at=end_time + completed_at=end_time, ) finally: - self._sync_status.is_running = False \ No newline at end of file + self._sync_status.is_running = False diff --git a/leggend/utils/__init__.py b/leggend/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/leggend/utils/gocardless.py b/leggend/utils/gocardless.py index 1eb5e72..13db2d3 100644 --- a/leggend/utils/gocardless.py +++ b/leggend/utils/gocardless.py @@ -7,4 +7,4 @@ REQUISITION_STATUS = { "GA": "GRANTING_ACCESS", "LN": "LINKED", "EX": "EXPIRED", -} \ No newline at end of file +} diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py index 8fd77fa..ba94be4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Pytest configuration and shared fixtures.""" + import pytest import tempfile import json @@ -19,34 +20,27 @@ def temp_config_dir(): yield config_dir -@pytest.fixture +@pytest.fixture def mock_config(temp_config_dir): """Mock configuration for testing.""" config_data = { "gocardless": { "key": "test-key", - "secret": "test-secret", - "url": "https://bankaccountdata.gocardless.com/api/v2" + "secret": "test-secret", + "url": "https://bankaccountdata.gocardless.com/api/v2", }, - "database": { - "sqlite": True - }, - "scheduler": { - "sync": { - "enabled": True, - "hour": 3, - "minute": 0 - } - } + "database": {"sqlite": True}, + "scheduler": {"sync": {"enabled": True, "hour": 3, "minute": 0}}, } - + config_file = temp_config_dir / "config.toml" with open(config_file, "wb") as f: import tomli_w + tomli_w.dump(config_data, f) - + # Mock the config path - with patch.object(Config, 'load_config') as mock_load: + with patch.object(Config, "load_config") as mock_load: mock_load.return_value = config_data config = Config() config._config = config_data @@ -57,15 +51,12 @@ def mock_config(temp_config_dir): @pytest.fixture def mock_auth_token(temp_config_dir): """Mock GoCardless authentication token.""" - auth_data = { - "access": "mock-access-token", - "refresh": "mock-refresh-token" - } - - auth_file = temp_config_dir / "auth.json" + auth_data = {"access": "mock-access-token", "refresh": "mock-refresh-token"} + + auth_file = temp_config_dir / "auth.json" with open(auth_file, "w") as f: json.dump(auth_data, f) - + return auth_data @@ -88,17 +79,17 @@ def sample_bank_data(): { "id": "REVOLUT_REVOLT21", "name": "Revolut", - "bic": "REVOLT21", + "bic": "REVOLT21", "transaction_total_days": 90, - "countries": ["GB", "LT"] + "countries": ["GB", "LT"], }, { "id": "BANCOBPI_BBPIPTPL", "name": "Banco BPI", "bic": "BBPIPTPL", - "transaction_total_days": 90, - "countries": ["PT"] - } + "transaction_total_days": 90, + "countries": ["PT"], + }, ] @@ -107,31 +98,28 @@ def sample_account_data(): """Sample account data for testing.""" return { "id": "test-account-123", - "institution_id": "REVOLUT_REVOLT21", + "institution_id": "REVOLUT_REVOLT21", "status": "READY", "iban": "LT313250081177977789", "created": "2024-02-13T23:56:00Z", - "last_accessed": "2025-09-01T09:30:00Z" + "last_accessed": "2025-09-01T09:30:00Z", } @pytest.fixture def sample_transaction_data(): - """Sample transaction data for testing.""" + """Sample transaction data for testing.""" return { "transactions": { "booked": [ { "internalTransactionId": "txn-123", "bookingDate": "2025-09-01", - "valueDate": "2025-09-01", - "transactionAmount": { - "amount": "-10.50", - "currency": "EUR" - }, - "remittanceInformationUnstructured": "Coffee Shop Payment" + "valueDate": "2025-09-01", + "transactionAmount": {"amount": "-10.50", "currency": "EUR"}, + "remittanceInformationUnstructured": "Coffee Shop Payment", } ], - "pending": [] + "pending": [], } - } \ No newline at end of file + } diff --git a/tests/unit/test_api_accounts.py b/tests/unit/test_api_accounts.py index 6a3df0e..5151167 100644 --- a/tests/unit/test_api_accounts.py +++ b/tests/unit/test_api_accounts.py @@ -1,4 +1,5 @@ """Tests for accounts API endpoints.""" + import pytest import respx import httpx @@ -8,48 +9,47 @@ from unittest.mock import patch @pytest.mark.api class TestAccountsAPI: """Test account-related API endpoints.""" - + @respx.mock - def test_get_all_accounts_success(self, api_client, mock_config, mock_auth_token, sample_account_data): + def test_get_all_accounts_success( + self, api_client, mock_config, mock_auth_token, sample_account_data + ): """Test successful retrieval of all accounts.""" requisitions_data = { - "results": [ - { - "id": "req-123", - "accounts": ["test-account-123"] - } - ] + "results": [{"id": "req-123", "accounts": ["test-account-123"]}] } - + balances_data = { "balances": [ { "balanceAmount": {"amount": "100.50", "currency": "EUR"}, "balanceType": "interimAvailable", - "lastChangeDateTime": "2025-09-01T09:30:00Z" + "lastChangeDateTime": "2025-09-01T09:30:00Z", } ] } - + # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless API calls respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( return_value=httpx.Response(200, json=requisitions_data) ) - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( - return_value=httpx.Response(200, json=sample_account_data) - ) - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( - return_value=httpx.Response(200, json=balances_data) - ) - - with patch('leggend.config.config', mock_config): + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/" + ).mock(return_value=httpx.Response(200, json=sample_account_data)) + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/" + ).mock(return_value=httpx.Response(200, json=balances_data)) + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/accounts") - + assert response.status_code == 200 data = response.json() assert data["success"] is True @@ -61,33 +61,37 @@ class TestAccountsAPI: assert account["balances"][0]["amount"] == 100.50 @respx.mock - def test_get_account_details_success(self, api_client, mock_config, mock_auth_token, sample_account_data): + def test_get_account_details_success( + self, api_client, mock_config, mock_auth_token, sample_account_data + ): """Test successful retrieval of specific account details.""" balances_data = { "balances": [ { "balanceAmount": {"amount": "250.75", "currency": "EUR"}, - "balanceType": "interimAvailable" + "balanceType": "interimAvailable", } ] } - + # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless API calls - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( - return_value=httpx.Response(200, json=sample_account_data) - ) - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( - return_value=httpx.Response(200, json=balances_data) - ) - - with patch('leggend.config.config', mock_config): + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/" + ).mock(return_value=httpx.Response(200, json=sample_account_data)) + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/" + ).mock(return_value=httpx.Response(200, json=balances_data)) + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/accounts/test-account-123") - + assert response.status_code == 200 data = response.json() assert data["success"] is True @@ -97,35 +101,39 @@ class TestAccountsAPI: assert len(account["balances"]) == 1 @respx.mock - def test_get_account_balances_success(self, api_client, mock_config, mock_auth_token): + def test_get_account_balances_success( + self, api_client, mock_config, mock_auth_token + ): """Test successful retrieval of account balances.""" balances_data = { "balances": [ { "balanceAmount": {"amount": "1000.00", "currency": "EUR"}, "balanceType": "interimAvailable", - "lastChangeDateTime": "2025-09-01T10:00:00Z" + "lastChangeDateTime": "2025-09-01T10:00:00Z", }, { - "balanceAmount": {"amount": "950.00", "currency": "EUR"}, - "balanceType": "expected" - } + "balanceAmount": {"amount": "950.00", "currency": "EUR"}, + "balanceType": "expected", + }, ] } - + # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless API - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/").mock( - return_value=httpx.Response(200, json=balances_data) - ) - - with patch('leggend.config.config', mock_config): + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/balances/" + ).mock(return_value=httpx.Response(200, json=balances_data)) + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/accounts/test-account-123/balances") - + assert response.status_code == 200 data = response.json() assert data["success"] is True @@ -135,29 +143,40 @@ class TestAccountsAPI: assert data["data"][0]["balance_type"] == "interimAvailable" @respx.mock - def test_get_account_transactions_success(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data): + def test_get_account_transactions_success( + self, + api_client, + mock_config, + mock_auth_token, + sample_account_data, + sample_transaction_data, + ): """Test successful retrieval of account transactions.""" # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless API calls - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( - return_value=httpx.Response(200, json=sample_account_data) - ) - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock( - return_value=httpx.Response(200, json=sample_transaction_data) - ) - - with patch('leggend.config.config', mock_config): - response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=true") - + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/" + ).mock(return_value=httpx.Response(200, json=sample_account_data)) + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/" + ).mock(return_value=httpx.Response(200, json=sample_transaction_data)) + + with patch("leggend.config.config", mock_config): + response = api_client.get( + "/api/v1/accounts/test-account-123/transactions?summary_only=true" + ) + assert response.status_code == 200 data = response.json() assert data["success"] is True assert len(data["data"]) == 1 - + transaction = data["data"][0] assert transaction["internal_transaction_id"] == "txn-123" assert transaction["amount"] == -10.50 @@ -165,29 +184,40 @@ class TestAccountsAPI: assert transaction["description"] == "Coffee Shop Payment" @respx.mock - def test_get_account_transactions_full_details(self, api_client, mock_config, mock_auth_token, sample_account_data, sample_transaction_data): + def test_get_account_transactions_full_details( + self, + api_client, + mock_config, + mock_auth_token, + sample_account_data, + sample_transaction_data, + ): """Test retrieval of full transaction details.""" # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless API calls - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/").mock( - return_value=httpx.Response(200, json=sample_account_data) - ) - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/").mock( - return_value=httpx.Response(200, json=sample_transaction_data) - ) - - with patch('leggend.config.config', mock_config): - response = api_client.get("/api/v1/accounts/test-account-123/transactions?summary_only=false") - + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/" + ).mock(return_value=httpx.Response(200, json=sample_account_data)) + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/test-account-123/transactions/" + ).mock(return_value=httpx.Response(200, json=sample_transaction_data)) + + with patch("leggend.config.config", mock_config): + response = api_client.get( + "/api/v1/accounts/test-account-123/transactions?summary_only=false" + ) + assert response.status_code == 200 data = response.json() assert data["success"] is True assert len(data["data"]) == 1 - + transaction = data["data"][0] assert transaction["internal_transaction_id"] == "txn-123" assert transaction["institution_id"] == "REVOLUT_REVOLT21" @@ -200,14 +230,18 @@ class TestAccountsAPI: with respx.mock: # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - - respx.get("https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/").mock( + + respx.get( + "https://bankaccountdata.gocardless.com/api/v2/accounts/nonexistent/" + ).mock( return_value=httpx.Response(404, json={"detail": "Account not found"}) ) - - with patch('leggend.config.config', mock_config): + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/accounts/nonexistent") - - assert response.status_code == 404 \ No newline at end of file + + assert response.status_code == 404 diff --git a/tests/unit/test_api_banks.py b/tests/unit/test_api_banks.py index 8f6cd3b..7746b47 100644 --- a/tests/unit/test_api_banks.py +++ b/tests/unit/test_api_banks.py @@ -1,4 +1,5 @@ """Tests for banks API endpoints.""" + import pytest import respx import httpx @@ -10,23 +11,27 @@ from leggend.services.gocardless_service import GoCardlessService @pytest.mark.api class TestBanksAPI: """Test bank-related API endpoints.""" - + @respx.mock - def test_get_institutions_success(self, api_client, mock_config, mock_auth_token, sample_bank_data): + def test_get_institutions_success( + self, api_client, mock_config, mock_auth_token, sample_bank_data + ): """Test successful retrieval of bank institutions.""" # Mock GoCardless token creation/refresh respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless institutions API respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock( return_value=httpx.Response(200, json=sample_bank_data) ) - - with patch('leggend.config.config', mock_config): + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/banks/institutions?country=PT") - + assert response.status_code == 200 data = response.json() assert data["success"] is True @@ -39,17 +44,19 @@ class TestBanksAPI: """Test institutions endpoint with invalid country code.""" # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock empty institutions response for invalid country respx.get("https://bankaccountdata.gocardless.com/api/v2/institutions/").mock( return_value=httpx.Response(200, json=[]) ) - - with patch('leggend.config.config', mock_config): + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/banks/institutions?country=XX") - + # Should still work but return empty or filtered results assert response.status_code in [200, 404] @@ -61,27 +68,29 @@ class TestBanksAPI: "institution_id": "REVOLUT_REVOLT21", "status": "CR", "created": "2025-09-02T00:00:00Z", - "link": "https://example.com/auth" + "link": "https://example.com/auth", } - + # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless requisitions API respx.post("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( return_value=httpx.Response(200, json=requisition_data) ) - + request_data = { "institution_id": "REVOLUT_REVOLT21", - "redirect_url": "http://localhost:8000/" + "redirect_url": "http://localhost:8000/", } - - with patch('leggend.config.config', mock_config): + + with patch("leggend.config.config", mock_config): response = api_client.post("/api/v1/banks/connect", json=request_data) - + assert response.status_code == 200 data = response.json() assert data["success"] is True @@ -95,27 +104,29 @@ class TestBanksAPI: "results": [ { "id": "req-123", - "institution_id": "REVOLUT_REVOLT21", + "institution_id": "REVOLUT_REVOLT21", "status": "LN", "created": "2025-09-02T00:00:00Z", - "accounts": ["acc-123"] + "accounts": ["acc-123"], } ] } - + # Mock GoCardless token creation respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( - return_value=httpx.Response(200, json={"access": "test-token", "refresh": "test-refresh"}) + return_value=httpx.Response( + 200, json={"access": "test-token", "refresh": "test-refresh"} + ) ) - + # Mock GoCardless requisitions API respx.get("https://bankaccountdata.gocardless.com/api/v2/requisitions/").mock( return_value=httpx.Response(200, json=requisitions_data) ) - - with patch('leggend.config.config', mock_config): + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/banks/status") - + assert response.status_code == 200 data = response.json() assert data["success"] is True @@ -126,12 +137,12 @@ class TestBanksAPI: def test_get_supported_countries(self, api_client): """Test supported countries endpoint.""" response = api_client.get("/api/v1/banks/countries") - + assert response.status_code == 200 data = response.json() assert data["success"] is True assert len(data["data"]) > 0 - + # Check some expected countries country_codes = [country["code"] for country in data["data"]] assert "PT" in country_codes @@ -145,10 +156,10 @@ class TestBanksAPI: respx.post("https://bankaccountdata.gocardless.com/api/v2/token/new/").mock( return_value=httpx.Response(401, json={"detail": "Invalid credentials"}) ) - - with patch('leggend.config.config', mock_config): + + with patch("leggend.config.config", mock_config): response = api_client.get("/api/v1/banks/institutions") - + assert response.status_code == 500 data = response.json() - assert "Failed to get institutions" in data["detail"] \ No newline at end of file + assert "Failed to get institutions" in data["detail"] diff --git a/tests/unit/test_api_client.py b/tests/unit/test_api_client.py index 5f68a52..75f8c0d 100644 --- a/tests/unit/test_api_client.py +++ b/tests/unit/test_api_client.py @@ -1,4 +1,5 @@ """Tests for CLI API client.""" + import pytest import requests_mock from unittest.mock import patch @@ -13,36 +14,36 @@ class TestLeggendAPIClient: def test_health_check_success(self): """Test successful health check.""" client = LeggendAPIClient("http://localhost:8000") - + with requests_mock.Mocker() as m: m.get("http://localhost:8000/health", json={"status": "healthy"}) - + result = client.health_check() assert result is True def test_health_check_failure(self): """Test health check failure.""" client = LeggendAPIClient("http://localhost:8000") - + with requests_mock.Mocker() as m: m.get("http://localhost:8000/health", status_code=500) - + result = client.health_check() assert result is False def test_get_institutions_success(self, sample_bank_data): """Test getting institutions via API client.""" client = LeggendAPIClient("http://localhost:8000") - + api_response = { "success": True, "data": sample_bank_data, - "message": "Found 2 institutions for PT" + "message": "Found 2 institutions for PT", } - + with requests_mock.Mocker() as m: m.get("http://localhost:8000/api/v1/banks/institutions", json=api_response) - + result = client.get_institutions("PT") assert len(result) == 2 assert result[0]["id"] == "REVOLUT_REVOLT21" @@ -50,16 +51,16 @@ class TestLeggendAPIClient: def test_get_accounts_success(self, sample_account_data): """Test getting accounts via API client.""" client = LeggendAPIClient("http://localhost:8000") - + api_response = { "success": True, "data": [sample_account_data], - "message": "Retrieved 1 accounts" + "message": "Retrieved 1 accounts", } - + with requests_mock.Mocker() as m: m.get("http://localhost:8000/api/v1/accounts", json=api_response) - + result = client.get_accounts() assert len(result) == 1 assert result[0]["id"] == "test-account-123" @@ -67,34 +68,37 @@ class TestLeggendAPIClient: def test_trigger_sync_success(self): """Test triggering sync via API client.""" client = LeggendAPIClient("http://localhost:8000") - + api_response = { "success": True, "data": {"sync_started": True, "force": False}, - "message": "Started sync for all accounts" + "message": "Started sync for all accounts", } - + with requests_mock.Mocker() as m: m.post("http://localhost:8000/api/v1/sync", json=api_response) - + result = client.trigger_sync() assert result["sync_started"] is True def test_connection_error_handling(self): """Test handling of connection errors.""" client = LeggendAPIClient("http://localhost:9999") # Non-existent service - + with pytest.raises(Exception): client.get_accounts() def test_http_error_handling(self): """Test handling of HTTP errors.""" client = LeggendAPIClient("http://localhost:8000") - + with requests_mock.Mocker() as m: - m.get("http://localhost:8000/api/v1/accounts", status_code=500, - json={"detail": "Internal server error"}) - + m.get( + "http://localhost:8000/api/v1/accounts", + status_code=500, + json={"detail": "Internal server error"}, + ) + with pytest.raises(Exception): client.get_accounts() @@ -102,28 +106,28 @@ class TestLeggendAPIClient: """Test using custom API URL.""" custom_url = "http://custom-host:9000" client = LeggendAPIClient(custom_url) - + assert client.base_url == custom_url def test_environment_variable_url(self): """Test using environment variable for API URL.""" - with patch.dict('os.environ', {'LEGGEND_API_URL': 'http://env-host:7000'}): + with patch.dict("os.environ", {"LEGGEND_API_URL": "http://env-host:7000"}): client = LeggendAPIClient() assert client.base_url == "http://env-host:7000" def test_sync_with_options(self): """Test sync with various options.""" client = LeggendAPIClient("http://localhost:8000") - + api_response = { "success": True, "data": {"sync_started": True, "force": True}, - "message": "Started sync for 2 specific accounts" + "message": "Started sync for 2 specific accounts", } - + with requests_mock.Mocker() as m: m.post("http://localhost:8000/api/v1/sync", json=api_response) - + result = client.trigger_sync(account_ids=["acc1", "acc2"], force=True) assert result["sync_started"] is True assert result["force"] is True @@ -131,20 +135,20 @@ class TestLeggendAPIClient: def test_get_scheduler_config(self): """Test getting scheduler configuration.""" client = LeggendAPIClient("http://localhost:8000") - + api_response = { "success": True, "data": { "enabled": True, "hour": 3, "minute": 0, - "next_scheduled_sync": "2025-09-03T03:00:00Z" - } + "next_scheduled_sync": "2025-09-03T03:00:00Z", + }, } - + with requests_mock.Mocker() as m: m.get("http://localhost:8000/api/v1/sync/scheduler", json=api_response) - + result = client.get_scheduler_config() assert result["enabled"] is True - assert result["hour"] == 3 \ No newline at end of file + assert result["hour"] == 3 diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 8749ea4..4aeecf7 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,4 +1,5 @@ """Tests for configuration management.""" + import pytest import tempfile from pathlib import Path @@ -23,25 +24,24 @@ class TestConfig: "gocardless": { "key": "test-key", "secret": "test-secret", - "url": "https://test.example.com" + "url": "https://test.example.com", }, - "database": { - "sqlite": True - } + "database": {"sqlite": True}, } - + config_file = temp_config_dir / "config.toml" with open(config_file, "wb") as f: import tomli_w + tomli_w.dump(config_data, f) - + config = Config() # Reset singleton state for testing config._config = None config._config_path = None - + result = config.load_config(str(config_file)) - + assert result == config_data assert config.gocardless_config["key"] == "test-key" assert config.database_config["sqlite"] is True @@ -50,87 +50,84 @@ class TestConfig: """Test handling of missing configuration file.""" config = Config() config._config = None # Reset for test - + with pytest.raises(FileNotFoundError): config.load_config("/nonexistent/config.toml") def test_save_config_success(self, temp_config_dir): """Test successful configuration saving.""" - config_data = { - "gocardless": { - "key": "new-key", - "secret": "new-secret" - } - } - + config_data = {"gocardless": {"key": "new-key", "secret": "new-secret"}} + config_file = temp_config_dir / "new_config.toml" config = Config() config._config = None - + config.save_config(config_data, str(config_file)) - + # Verify file was created and contains correct data assert config_file.exists() - + import tomllib + with open(config_file, "rb") as f: saved_data = tomllib.load(f) - + assert saved_data == config_data def test_update_config_success(self, temp_config_dir): """Test updating configuration values.""" initial_config = { "gocardless": {"key": "old-key"}, - "database": {"sqlite": True} + "database": {"sqlite": True}, } - + config_file = temp_config_dir / "config.toml" with open(config_file, "wb") as f: import tomli_w + tomli_w.dump(initial_config, f) - + config = Config() config._config = None config.load_config(str(config_file)) - + config.update_config("gocardless", "key", "new-key") - + assert config.gocardless_config["key"] == "new-key" - + # Verify it was saved to file import tomllib + with open(config_file, "rb") as f: saved_data = tomllib.load(f) assert saved_data["gocardless"]["key"] == "new-key" def test_update_section_success(self, temp_config_dir): """Test updating entire configuration section.""" - initial_config = { - "database": {"sqlite": True} - } - + initial_config = {"database": {"sqlite": True}} + config_file = temp_config_dir / "config.toml" with open(config_file, "wb") as f: import tomli_w + tomli_w.dump(initial_config, f) - + config = Config() config._config = None config.load_config(str(config_file)) - + new_db_config = {"sqlite": False, "path": "./custom.db"} config.update_section("database", new_db_config) - + assert config.database_config == new_db_config def test_scheduler_config_defaults(self): """Test scheduler configuration with defaults.""" config = Config() config._config = {} # Empty config - + scheduler_config = config.scheduler_config - + assert scheduler_config["sync"]["enabled"] is True assert scheduler_config["sync"]["hour"] == 3 assert scheduler_config["sync"]["minute"] == 0 @@ -144,16 +141,16 @@ class TestConfig: "enabled": False, "hour": 6, "minute": 30, - "cron": "0 6 * * 1-5" + "cron": "0 6 * * 1-5", } } } - + config = Config() config._config = custom_config - + scheduler_config = config.scheduler_config - + assert scheduler_config["sync"]["enabled"] is False assert scheduler_config["sync"]["hour"] == 6 assert scheduler_config["sync"]["minute"] == 30 @@ -161,26 +158,28 @@ class TestConfig: def test_environment_variable_config_path(self): """Test using environment variable for config path.""" - with patch.dict('os.environ', {'LEGGEN_CONFIG_FILE': '/custom/path/config.toml'}): + with patch.dict( + "os.environ", {"LEGGEN_CONFIG_FILE": "/custom/path/config.toml"} + ): config = Config() config._config = None - - with patch('builtins.open', side_effect=FileNotFoundError): + + with patch("builtins.open", side_effect=FileNotFoundError): with pytest.raises(FileNotFoundError): config.load_config() def test_notifications_config(self): - """Test notifications configuration access.""" + """Test notifications configuration access.""" custom_config = { "notifications": { "discord": {"webhook": "https://discord.webhook", "enabled": True}, - "telegram": {"token": "bot-token", "chat_id": 123} + "telegram": {"token": "bot-token", "chat_id": 123}, } } - + config = Config() config._config = custom_config - + notifications = config.notifications_config assert notifications["discord"]["webhook"] == "https://discord.webhook" assert notifications["telegram"]["token"] == "bot-token" @@ -190,13 +189,13 @@ class TestConfig: custom_config = { "filters": { "case-insensitive": {"salary": "SALARY", "bills": "BILL"}, - "amount_threshold": 100.0 + "amount_threshold": 100.0, } } - + config = Config() config._config = custom_config - + filters = config.filters_config assert filters["case-insensitive"]["salary"] == "SALARY" - assert filters["amount_threshold"] == 100.0 \ No newline at end of file + assert filters["amount_threshold"] == 100.0 diff --git a/tests/unit/test_scheduler.py b/tests/unit/test_scheduler.py index 3206aae..a32c933 100644 --- a/tests/unit/test_scheduler.py +++ b/tests/unit/test_scheduler.py @@ -1,4 +1,5 @@ """Tests for background scheduler.""" + import pytest import asyncio from unittest.mock import Mock, patch, AsyncMock, MagicMock @@ -16,23 +17,19 @@ class TestBackgroundScheduler: @pytest.fixture def mock_config(self): """Mock configuration for scheduler tests.""" - return { - "sync": { - "enabled": True, - "hour": 3, - "minute": 0, - "cron": None - } - } + return {"sync": {"enabled": True, "hour": 3, "minute": 0, "cron": None}} @pytest.fixture def scheduler(self): """Create scheduler instance for testing.""" - with patch('leggend.background.scheduler.SyncService'), \ - patch('leggend.background.scheduler.config') as mock_config: - - mock_config.scheduler_config = {"sync": {"enabled": True, "hour": 3, "minute": 0}} - + with ( + patch("leggend.background.scheduler.SyncService"), + patch("leggend.background.scheduler.config") as mock_config, + ): + mock_config.scheduler_config = { + "sync": {"enabled": True, "hour": 3, "minute": 0} + } + # Create scheduler and replace its AsyncIO scheduler with a mock scheduler = BackgroundScheduler() mock_scheduler = MagicMock() @@ -43,35 +40,34 @@ class TestBackgroundScheduler: def test_scheduler_start_default_config(self, scheduler, mock_config): """Test starting scheduler with default configuration.""" - with patch('leggend.config.config') as mock_config_obj: + with patch("leggend.config.config") as mock_config_obj: mock_config_obj.scheduler_config = mock_config - + # Mock the job that gets added mock_job = MagicMock() mock_job.id = "daily_sync" scheduler.scheduler.get_jobs.return_value = [mock_job] - + scheduler.start() - + # Verify scheduler.start() was called scheduler.scheduler.start.assert_called_once() - # Verify add_job was called + # Verify add_job was called scheduler.scheduler.add_job.assert_called_once() def test_scheduler_start_disabled(self, scheduler): """Test scheduler behavior when sync is disabled.""" - disabled_config = { - "sync": {"enabled": False} - } - - with patch.object(scheduler, 'scheduler') as mock_scheduler, \ - patch('leggend.background.scheduler.config') as mock_config_obj: - + disabled_config = {"sync": {"enabled": False}} + + with ( + patch.object(scheduler, "scheduler") as mock_scheduler, + patch("leggend.background.scheduler.config") as mock_config_obj, + ): mock_config_obj.scheduler_config = disabled_config mock_scheduler.running = False - + scheduler.start() - + # Verify scheduler.start() was called mock_scheduler.start.assert_called_once() # Verify add_job was NOT called for disabled sync @@ -82,39 +78,35 @@ class TestBackgroundScheduler: cron_config = { "sync": { "enabled": True, - "cron": "0 6 * * 1-5" # 6 AM on weekdays + "cron": "0 6 * * 1-5", # 6 AM on weekdays } } - - with patch('leggend.config.config') as mock_config_obj: + + with patch("leggend.config.config") as mock_config_obj: mock_config_obj.scheduler_config = cron_config - + scheduler.start() - + # Verify scheduler.start() and add_job were called scheduler.scheduler.start.assert_called_once() scheduler.scheduler.add_job.assert_called_once() # Verify job was added with correct ID call_args = scheduler.scheduler.add_job.call_args - assert call_args.kwargs['id'] == 'daily_sync' + assert call_args.kwargs["id"] == "daily_sync" def test_scheduler_start_invalid_cron(self, scheduler): """Test handling of invalid cron expressions.""" - invalid_cron_config = { - "sync": { - "enabled": True, - "cron": "invalid cron" - } - } - - with patch.object(scheduler, 'scheduler') as mock_scheduler, \ - patch('leggend.background.scheduler.config') as mock_config_obj: - + invalid_cron_config = {"sync": {"enabled": True, "cron": "invalid cron"}} + + with ( + patch.object(scheduler, "scheduler") as mock_scheduler, + patch("leggend.background.scheduler.config") as mock_config_obj, + ): mock_config_obj.scheduler_config = invalid_cron_config mock_scheduler.running = False - + scheduler.start() - + # With invalid cron, scheduler.start() should not be called due to early return # and add_job should not be called mock_scheduler.start.assert_not_called() @@ -123,24 +115,20 @@ class TestBackgroundScheduler: def test_scheduler_shutdown(self, scheduler): """Test scheduler shutdown.""" scheduler.scheduler.running = True - + scheduler.shutdown() - + scheduler.scheduler.shutdown.assert_called_once() def test_reschedule_sync(self, scheduler, mock_config): """Test rescheduling sync job.""" scheduler.scheduler.running = True - + # Reschedule with new config - new_config = { - "enabled": True, - "hour": 6, - "minute": 30 - } - + new_config = {"enabled": True, "hour": 6, "minute": 30} + scheduler.reschedule_sync(new_config) - + # Verify remove_job and add_job were called scheduler.scheduler.remove_job.assert_called_once_with("daily_sync") scheduler.scheduler.add_job.assert_called_once() @@ -148,11 +136,11 @@ class TestBackgroundScheduler: def test_reschedule_sync_disable(self, scheduler, mock_config): """Test disabling sync via reschedule.""" scheduler.scheduler.running = True - + # Disable sync disabled_config = {"enabled": False} scheduler.reschedule_sync(disabled_config) - + # Job should be removed but not re-added scheduler.scheduler.remove_job.assert_called_once_with("daily_sync") scheduler.scheduler.add_job.assert_not_called() @@ -162,9 +150,9 @@ class TestBackgroundScheduler: mock_job = MagicMock() mock_job.next_run_time = datetime(2025, 9, 2, 3, 0) scheduler.scheduler.get_job.return_value = mock_job - + next_time = scheduler.get_next_sync_time() - + assert next_time is not None assert isinstance(next_time, datetime) scheduler.scheduler.get_job.assert_called_once_with("daily_sync") @@ -172,9 +160,9 @@ class TestBackgroundScheduler: def test_get_next_sync_time_no_job(self, scheduler): """Test getting next sync time when no job is scheduled.""" scheduler.scheduler.get_job.return_value = None - + next_time = scheduler.get_next_sync_time() - + assert next_time is None scheduler.scheduler.get_job.assert_called_once_with("daily_sync") @@ -183,9 +171,9 @@ class TestBackgroundScheduler: """Test successful sync job execution.""" mock_sync_service = AsyncMock() scheduler.sync_service = mock_sync_service - + await scheduler._run_sync() - + mock_sync_service.sync_all_accounts.assert_called_once() @pytest.mark.asyncio @@ -194,18 +182,18 @@ class TestBackgroundScheduler: mock_sync_service = AsyncMock() mock_sync_service.sync_all_accounts.side_effect = Exception("Sync failed") scheduler.sync_service = mock_sync_service - + # Should not raise exception, just log error await scheduler._run_sync() - + mock_sync_service.sync_all_accounts.assert_called_once() def test_scheduler_job_max_instances(self, scheduler, mock_config): """Test that sync jobs have max_instances=1.""" - with patch('leggend.config.config') as mock_config_obj: + with patch("leggend.config.config") as mock_config_obj: mock_config_obj.scheduler_config = mock_config scheduler.start() - + # Verify add_job was called with max_instances=1 call_args = scheduler.scheduler.add_job.call_args - assert call_args.kwargs['max_instances'] == 1 \ No newline at end of file + assert call_args.kwargs["max_instances"] == 1