mirror of
https://github.com/elisiariocouto/leggen.git
synced 2025-12-25 00:19:37 +00:00
refactor: Unify leggen and leggend packages into single leggen package
- Merge leggend API components into leggen (api/, services/, background/) - Replace leggend command with 'leggen server' subcommand - Consolidate configuration systems into leggen.utils.config - Update environment variables: LEGGEND_API_URL -> LEGGEN_API_URL - Rename LeggendAPIClient -> LeggenAPIClient - Update all documentation, Docker configs, and compose files - Fix all import statements and test references - Remove duplicate utility files and clean up package structure All tests passing (101/101), linting clean, server functionality preserved. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
Elisiário Couto
parent
0e645d9bae
commit
318ca517f7
77
leggen/api/models/accounts.py
Normal file
77
leggen/api/models/accounts.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AccountBalance(BaseModel):
|
||||
"""Account balance model"""
|
||||
|
||||
amount: float
|
||||
currency: str
|
||||
balance_type: str
|
||||
last_change_date: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class AccountDetails(BaseModel):
|
||||
"""Account details model"""
|
||||
|
||||
id: str
|
||||
institution_id: str
|
||||
status: str
|
||||
iban: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
currency: Optional[str] = None
|
||||
created: datetime
|
||||
last_accessed: Optional[datetime] = None
|
||||
balances: List[AccountBalance] = []
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
"""Account update model"""
|
||||
|
||||
name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class Transaction(BaseModel):
|
||||
"""Transaction model"""
|
||||
|
||||
transaction_id: str # NEW: stable bank-provided transaction ID
|
||||
internal_transaction_id: Optional[str] = None # OLD: unstable GoCardless ID
|
||||
institution_id: str
|
||||
iban: Optional[str] = None
|
||||
account_id: str
|
||||
transaction_date: datetime
|
||||
description: str
|
||||
transaction_value: float
|
||||
transaction_currency: str
|
||||
transaction_status: str # "booked" or "pending"
|
||||
raw_transaction: Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class TransactionSummary(BaseModel):
|
||||
"""Transaction summary for lists"""
|
||||
|
||||
transaction_id: str # NEW: stable bank-provided transaction ID
|
||||
internal_transaction_id: Optional[str] = None
|
||||
date: datetime
|
||||
description: str
|
||||
amount: float
|
||||
currency: str
|
||||
status: str
|
||||
account_id: str
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
52
leggen/api/models/banks.py
Normal file
52
leggen/api/models/banks.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BankInstitution(BaseModel):
|
||||
"""Bank institution model"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
bic: Optional[str] = None
|
||||
transaction_total_days: int
|
||||
countries: List[str]
|
||||
logo: Optional[str] = None
|
||||
|
||||
|
||||
class BankConnectionRequest(BaseModel):
|
||||
"""Request to connect to a bank"""
|
||||
|
||||
institution_id: str
|
||||
redirect_url: Optional[str] = "http://localhost:8000/"
|
||||
|
||||
|
||||
class BankRequisition(BaseModel):
|
||||
"""Bank requisition/connection model"""
|
||||
|
||||
id: str
|
||||
institution_id: str
|
||||
status: str
|
||||
status_display: Optional[str] = None
|
||||
created: datetime
|
||||
link: str
|
||||
accounts: List[str] = []
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class BankConnectionStatus(BaseModel):
|
||||
"""Bank connection status response"""
|
||||
|
||||
bank_id: str
|
||||
bank_name: str
|
||||
status: str
|
||||
status_display: str
|
||||
created_at: datetime
|
||||
requisition_id: str
|
||||
accounts_count: int
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
29
leggen/api/models/common.py
Normal file
29
leggen/api/models/common.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIResponse(BaseModel):
|
||||
"""Base API response model"""
|
||||
|
||||
success: bool = True
|
||||
message: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response model"""
|
||||
|
||||
success: bool = False
|
||||
message: str
|
||||
error_code: Optional[str] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""Paginated response model"""
|
||||
|
||||
success: bool = True
|
||||
data: list
|
||||
pagination: Dict[str, Any]
|
||||
message: Optional[str] = None
|
||||
51
leggen/api/models/notifications.py
Normal file
51
leggen/api/models/notifications.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DiscordConfig(BaseModel):
|
||||
"""Discord notification configuration"""
|
||||
|
||||
webhook: str
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class TelegramConfig(BaseModel):
|
||||
"""Telegram notification configuration"""
|
||||
|
||||
token: str
|
||||
chat_id: int
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class NotificationFilters(BaseModel):
|
||||
"""Notification filters configuration"""
|
||||
|
||||
case_insensitive: List[str] = []
|
||||
case_sensitive: Optional[List[str]] = None
|
||||
|
||||
|
||||
class NotificationSettings(BaseModel):
|
||||
"""Complete notification settings"""
|
||||
|
||||
discord: Optional[DiscordConfig] = None
|
||||
telegram: Optional[TelegramConfig] = None
|
||||
filters: NotificationFilters = NotificationFilters()
|
||||
|
||||
|
||||
class NotificationTest(BaseModel):
|
||||
"""Test notification request"""
|
||||
|
||||
service: str # "discord" or "telegram"
|
||||
message: str = "Test notification from Leggen"
|
||||
|
||||
|
||||
class NotificationHistory(BaseModel):
|
||||
"""Notification history entry"""
|
||||
|
||||
id: str
|
||||
service: str
|
||||
message: str
|
||||
status: str # "sent", "failed"
|
||||
sent_at: str
|
||||
error: Optional[str] = None
|
||||
55
leggen/api/models/sync.py
Normal file
55
leggen/api/models/sync.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SyncRequest(BaseModel):
|
||||
"""Request to trigger a sync"""
|
||||
|
||||
account_ids: Optional[list[str]] = None # If None, sync all accounts
|
||||
force: bool = False # Force sync even if recently synced
|
||||
|
||||
|
||||
class SyncStatus(BaseModel):
|
||||
"""Sync operation status"""
|
||||
|
||||
is_running: bool
|
||||
last_sync: Optional[datetime] = None
|
||||
next_sync: Optional[datetime] = None
|
||||
accounts_synced: int = 0
|
||||
total_accounts: int = 0
|
||||
transactions_added: int = 0
|
||||
errors: list[str] = []
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class SyncResult(BaseModel):
|
||||
"""Result of a sync operation"""
|
||||
|
||||
success: bool
|
||||
accounts_processed: int
|
||||
transactions_added: int
|
||||
transactions_updated: int
|
||||
balances_updated: int
|
||||
duration_seconds: float
|
||||
errors: list[str] = []
|
||||
started_at: datetime
|
||||
completed_at: datetime
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class SchedulerConfig(BaseModel):
|
||||
"""Scheduler configuration model"""
|
||||
|
||||
enabled: bool = True
|
||||
hour: Optional[int] = 3
|
||||
minute: Optional[int] = 0
|
||||
cron: Optional[str] = None # Custom cron expression
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
356
leggen/api/routes/accounts.py
Normal file
356
leggen/api/routes/accounts.py
Normal file
@@ -0,0 +1,356 @@
|
||||
from typing import Optional, List, Union
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.models.common import APIResponse
|
||||
from leggen.api.models.accounts import (
|
||||
AccountDetails,
|
||||
AccountBalance,
|
||||
Transaction,
|
||||
TransactionSummary,
|
||||
AccountUpdate,
|
||||
)
|
||||
from leggen.services.database_service import DatabaseService
|
||||
|
||||
router = APIRouter()
|
||||
database_service = DatabaseService()
|
||||
|
||||
|
||||
@router.get("/accounts", response_model=APIResponse)
|
||||
async def get_all_accounts() -> APIResponse:
|
||||
"""Get all connected accounts from database"""
|
||||
try:
|
||||
accounts = []
|
||||
|
||||
# Get all account details from database
|
||||
db_accounts = await database_service.get_accounts_from_db()
|
||||
|
||||
# Process accounts found in database
|
||||
for db_account in db_accounts:
|
||||
try:
|
||||
# Get latest balances from database for this account
|
||||
balances_data = await database_service.get_balances_from_db(
|
||||
db_account["id"]
|
||||
)
|
||||
|
||||
# Process balances
|
||||
balances = []
|
||||
for balance in balances_data:
|
||||
balances.append(
|
||||
AccountBalance(
|
||||
amount=balance["amount"],
|
||||
currency=balance["currency"],
|
||||
balance_type=balance["type"],
|
||||
last_change_date=balance.get("timestamp"),
|
||||
)
|
||||
)
|
||||
|
||||
accounts.append(
|
||||
AccountDetails(
|
||||
id=db_account["id"],
|
||||
institution_id=db_account["institution_id"],
|
||||
status=db_account["status"],
|
||||
iban=db_account.get("iban"),
|
||||
name=db_account.get("name"),
|
||||
currency=db_account.get("currency"),
|
||||
created=db_account["created"],
|
||||
last_accessed=db_account.get("last_accessed"),
|
||||
balances=balances,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process database account {db_account['id']}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=accounts,
|
||||
message=f"Retrieved {len(accounts)} accounts from database",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get accounts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get accounts: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/accounts/{account_id}", response_model=APIResponse)
|
||||
async def get_account_details(account_id: str) -> APIResponse:
|
||||
"""Get details for a specific account from database"""
|
||||
try:
|
||||
# Get account details from database
|
||||
db_account = await database_service.get_account_details_from_db(account_id)
|
||||
|
||||
if not db_account:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Account {account_id} not found in database"
|
||||
)
|
||||
|
||||
# Get latest balances from database for this account
|
||||
balances_data = await database_service.get_balances_from_db(account_id)
|
||||
|
||||
# Process balances
|
||||
balances = []
|
||||
for balance in balances_data:
|
||||
balances.append(
|
||||
AccountBalance(
|
||||
amount=balance["amount"],
|
||||
currency=balance["currency"],
|
||||
balance_type=balance["type"],
|
||||
last_change_date=balance.get("timestamp"),
|
||||
)
|
||||
)
|
||||
|
||||
account = AccountDetails(
|
||||
id=db_account["id"],
|
||||
institution_id=db_account["institution_id"],
|
||||
status=db_account["status"],
|
||||
iban=db_account.get("iban"),
|
||||
name=db_account.get("name"),
|
||||
currency=db_account.get("currency"),
|
||||
created=db_account["created"],
|
||||
last_accessed=db_account.get("last_accessed"),
|
||||
balances=balances,
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=account,
|
||||
message=f"Account details retrieved from database for {account_id}",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get account details for {account_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get account details: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/accounts/{account_id}/balances", response_model=APIResponse)
|
||||
async def get_account_balances(account_id: str) -> APIResponse:
|
||||
"""Get balances for a specific account from database"""
|
||||
try:
|
||||
# Get balances from database instead of GoCardless API
|
||||
db_balances = await database_service.get_balances_from_db(account_id=account_id)
|
||||
|
||||
balances = []
|
||||
for balance in db_balances:
|
||||
balances.append(
|
||||
AccountBalance(
|
||||
amount=balance["amount"],
|
||||
currency=balance["currency"],
|
||||
balance_type=balance["type"],
|
||||
last_change_date=balance.get("timestamp"),
|
||||
)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=balances,
|
||||
message=f"Retrieved {len(balances)} balances for account {account_id}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get balances from database for account {account_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Failed to get balances: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/balances", response_model=APIResponse)
|
||||
async def get_all_balances() -> APIResponse:
|
||||
"""Get all balances from all accounts in database"""
|
||||
try:
|
||||
# Get all accounts first to iterate through them
|
||||
db_accounts = await database_service.get_accounts_from_db()
|
||||
|
||||
all_balances = []
|
||||
for db_account in db_accounts:
|
||||
try:
|
||||
# Get balances for this account
|
||||
db_balances = await database_service.get_balances_from_db(
|
||||
account_id=db_account["id"]
|
||||
)
|
||||
|
||||
# Process balances and add account info
|
||||
for balance in db_balances:
|
||||
balance_data = {
|
||||
"id": f"{db_account['id']}_{balance['type']}", # Create unique ID
|
||||
"account_id": db_account["id"],
|
||||
"balance_amount": balance["amount"],
|
||||
"balance_type": balance["type"],
|
||||
"currency": balance["currency"],
|
||||
"reference_date": balance.get(
|
||||
"timestamp", db_account.get("last_accessed")
|
||||
),
|
||||
"created_at": db_account.get("created"),
|
||||
"updated_at": db_account.get("last_accessed"),
|
||||
}
|
||||
all_balances.append(balance_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get balances for account {db_account['id']}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=all_balances,
|
||||
message=f"Retrieved {len(all_balances)} balances from {len(db_accounts)} accounts",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all balances: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get balances: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/balances/history", response_model=APIResponse)
|
||||
async def get_historical_balances(
|
||||
days: Optional[int] = Query(
|
||||
default=365, le=1095, ge=1, description="Number of days of history to retrieve"
|
||||
),
|
||||
account_id: Optional[str] = Query(
|
||||
default=None, description="Filter by specific account ID"
|
||||
),
|
||||
) -> APIResponse:
|
||||
"""Get historical balance progression calculated from transaction history"""
|
||||
try:
|
||||
# Get historical balances from database
|
||||
historical_balances = await database_service.get_historical_balances_from_db(
|
||||
account_id=account_id, days=days or 365
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=historical_balances,
|
||||
message=f"Retrieved {len(historical_balances)} historical balance points over {days} days",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get historical balances: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get historical balances: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/accounts/{account_id}/transactions", response_model=APIResponse)
|
||||
async def get_account_transactions(
|
||||
account_id: str,
|
||||
limit: Optional[int] = Query(default=100, le=500),
|
||||
offset: Optional[int] = Query(default=0, ge=0),
|
||||
summary_only: bool = Query(
|
||||
default=False, description="Return transaction summaries only"
|
||||
),
|
||||
) -> APIResponse:
|
||||
"""Get transactions for a specific account from database"""
|
||||
try:
|
||||
# Get transactions from database instead of GoCardless API
|
||||
db_transactions = await database_service.get_transactions_from_db(
|
||||
account_id=account_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Get total count for pagination info
|
||||
total_transactions = await database_service.get_transaction_count_from_db(
|
||||
account_id=account_id,
|
||||
)
|
||||
|
||||
data: Union[List[TransactionSummary], List[Transaction]]
|
||||
|
||||
if summary_only:
|
||||
# Return simplified transaction summaries
|
||||
data = [
|
||||
TransactionSummary(
|
||||
transaction_id=txn["transactionId"], # NEW: stable bank-provided ID
|
||||
internal_transaction_id=txn.get("internalTransactionId"),
|
||||
date=txn["transactionDate"],
|
||||
description=txn["description"],
|
||||
amount=txn["transactionValue"],
|
||||
currency=txn["transactionCurrency"],
|
||||
status=txn["transactionStatus"],
|
||||
account_id=txn["accountId"],
|
||||
)
|
||||
for txn in db_transactions
|
||||
]
|
||||
else:
|
||||
# Return full transaction details
|
||||
data = [
|
||||
Transaction(
|
||||
transaction_id=txn["transactionId"], # NEW: stable bank-provided ID
|
||||
internal_transaction_id=txn.get("internalTransactionId"),
|
||||
institution_id=txn["institutionId"],
|
||||
iban=txn["iban"],
|
||||
account_id=txn["accountId"],
|
||||
transaction_date=txn["transactionDate"],
|
||||
description=txn["description"],
|
||||
transaction_value=txn["transactionValue"],
|
||||
transaction_currency=txn["transactionCurrency"],
|
||||
transaction_status=txn["transactionStatus"],
|
||||
raw_transaction=txn["rawTransaction"],
|
||||
)
|
||||
for txn in db_transactions
|
||||
]
|
||||
|
||||
actual_offset = offset or 0
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=data,
|
||||
message=f"Retrieved {len(data)} transactions (showing {actual_offset + 1}-{actual_offset + len(data)} of {total_transactions})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get transactions from database for account {account_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Failed to get transactions: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/accounts/{account_id}", response_model=APIResponse)
|
||||
async def update_account_details(
|
||||
account_id: str, update_data: AccountUpdate
|
||||
) -> APIResponse:
|
||||
"""Update account details (currently only name)"""
|
||||
try:
|
||||
# Get current account details
|
||||
current_account = await database_service.get_account_details_from_db(account_id)
|
||||
|
||||
if not current_account:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Account {account_id} not found"
|
||||
)
|
||||
|
||||
# Prepare updated account data
|
||||
updated_account_data = current_account.copy()
|
||||
if update_data.name is not None:
|
||||
updated_account_data["name"] = update_data.name
|
||||
|
||||
# Persist updated account details
|
||||
await database_service.persist_account_details(updated_account_data)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"id": account_id, "name": update_data.name},
|
||||
message=f"Account {account_id} name updated successfully",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update account {account_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update account: {str(e)}"
|
||||
) from e
|
||||
179
leggen/api/routes/banks.py
Normal file
179
leggen/api/routes/banks.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.models.common import APIResponse
|
||||
from leggen.api.models.banks import (
|
||||
BankInstitution,
|
||||
BankConnectionRequest,
|
||||
BankRequisition,
|
||||
BankConnectionStatus,
|
||||
)
|
||||
from leggen.services.gocardless_service import GoCardlessService
|
||||
from leggen.utils.gocardless import REQUISITION_STATUS
|
||||
|
||||
router = APIRouter()
|
||||
gocardless_service = GoCardlessService()
|
||||
|
||||
|
||||
@router.get("/banks/institutions", response_model=APIResponse)
|
||||
async def get_bank_institutions(
|
||||
country: str = Query(default="PT", description="Country code (e.g., PT, ES, FR)"),
|
||||
) -> APIResponse:
|
||||
"""Get available bank institutions for a country"""
|
||||
try:
|
||||
institutions_data = await gocardless_service.get_institutions(country)
|
||||
|
||||
institutions = [
|
||||
BankInstitution(
|
||||
id=inst["id"],
|
||||
name=inst["name"],
|
||||
bic=inst.get("bic"),
|
||||
transaction_total_days=inst["transaction_total_days"],
|
||||
countries=inst["countries"],
|
||||
logo=inst.get("logo"),
|
||||
)
|
||||
for inst in institutions_data
|
||||
]
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=institutions,
|
||||
message=f"Found {len(institutions)} institutions for {country}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get institutions for {country}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get institutions: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/banks/connect", response_model=APIResponse)
|
||||
async def connect_to_bank(request: BankConnectionRequest) -> APIResponse:
|
||||
"""Create a connection to a bank (requisition)"""
|
||||
try:
|
||||
redirect_url = request.redirect_url or "http://localhost:8000/"
|
||||
requisition_data = await gocardless_service.create_requisition(
|
||||
request.institution_id, redirect_url
|
||||
)
|
||||
|
||||
requisition = BankRequisition(
|
||||
id=requisition_data["id"],
|
||||
institution_id=requisition_data["institution_id"],
|
||||
status=requisition_data["status"],
|
||||
created=requisition_data["created"],
|
||||
link=requisition_data["link"],
|
||||
accounts=requisition_data.get("accounts", []),
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=requisition,
|
||||
message="Bank connection created. Please visit the link to authorize.",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to bank {request.institution_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to connect to bank: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/banks/status", response_model=APIResponse)
|
||||
async def get_bank_connections_status() -> APIResponse:
|
||||
"""Get status of all bank connections"""
|
||||
try:
|
||||
requisitions_data = await gocardless_service.get_requisitions()
|
||||
|
||||
connections = []
|
||||
for req in requisitions_data.get("results", []):
|
||||
status = req["status"]
|
||||
status_display = REQUISITION_STATUS.get(status, "UNKNOWN")
|
||||
|
||||
connections.append(
|
||||
BankConnectionStatus(
|
||||
bank_id=req["institution_id"],
|
||||
bank_name=req[
|
||||
"institution_id"
|
||||
], # Could be enhanced with actual bank names
|
||||
status=status,
|
||||
status_display=status_display,
|
||||
created_at=req["created"],
|
||||
requisition_id=req["id"],
|
||||
accounts_count=len(req.get("accounts", [])),
|
||||
)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=connections,
|
||||
message=f"Found {len(connections)} bank connections",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get bank connection status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get bank status: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/banks/connections/{requisition_id}", response_model=APIResponse)
|
||||
async def delete_bank_connection(requisition_id: str) -> APIResponse:
|
||||
"""Delete a bank connection"""
|
||||
try:
|
||||
# This would need to be implemented in GoCardlessService
|
||||
# For now, return success
|
||||
return APIResponse(
|
||||
success=True,
|
||||
message=f"Bank connection {requisition_id} deleted successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete bank connection {requisition_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete connection: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/banks/countries", response_model=APIResponse)
|
||||
async def get_supported_countries() -> APIResponse:
|
||||
"""Get list of supported countries"""
|
||||
countries = [
|
||||
{"code": "AT", "name": "Austria"},
|
||||
{"code": "BE", "name": "Belgium"},
|
||||
{"code": "BG", "name": "Bulgaria"},
|
||||
{"code": "HR", "name": "Croatia"},
|
||||
{"code": "CY", "name": "Cyprus"},
|
||||
{"code": "CZ", "name": "Czech Republic"},
|
||||
{"code": "DK", "name": "Denmark"},
|
||||
{"code": "EE", "name": "Estonia"},
|
||||
{"code": "FI", "name": "Finland"},
|
||||
{"code": "FR", "name": "France"},
|
||||
{"code": "DE", "name": "Germany"},
|
||||
{"code": "GR", "name": "Greece"},
|
||||
{"code": "HU", "name": "Hungary"},
|
||||
{"code": "IS", "name": "Iceland"},
|
||||
{"code": "IE", "name": "Ireland"},
|
||||
{"code": "IT", "name": "Italy"},
|
||||
{"code": "LV", "name": "Latvia"},
|
||||
{"code": "LI", "name": "Liechtenstein"},
|
||||
{"code": "LT", "name": "Lithuania"},
|
||||
{"code": "LU", "name": "Luxembourg"},
|
||||
{"code": "MT", "name": "Malta"},
|
||||
{"code": "NL", "name": "Netherlands"},
|
||||
{"code": "NO", "name": "Norway"},
|
||||
{"code": "PL", "name": "Poland"},
|
||||
{"code": "PT", "name": "Portugal"},
|
||||
{"code": "RO", "name": "Romania"},
|
||||
{"code": "SK", "name": "Slovakia"},
|
||||
{"code": "SI", "name": "Slovenia"},
|
||||
{"code": "ES", "name": "Spain"},
|
||||
{"code": "SE", "name": "Sweden"},
|
||||
{"code": "GB", "name": "United Kingdom"},
|
||||
]
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=countries,
|
||||
message="Supported countries retrieved successfully",
|
||||
)
|
||||
203
leggen/api/routes/notifications.py
Normal file
203
leggen/api/routes/notifications.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.models.common import APIResponse
|
||||
from leggen.api.models.notifications import (
|
||||
NotificationSettings,
|
||||
NotificationTest,
|
||||
DiscordConfig,
|
||||
TelegramConfig,
|
||||
NotificationFilters,
|
||||
)
|
||||
from leggen.services.notification_service import NotificationService
|
||||
from leggen.utils.config import config
|
||||
|
||||
router = APIRouter()
|
||||
notification_service = NotificationService()
|
||||
|
||||
|
||||
@router.get("/notifications/settings", response_model=APIResponse)
|
||||
async def get_notification_settings() -> APIResponse:
|
||||
"""Get current notification settings"""
|
||||
try:
|
||||
notifications_config = config.notifications_config
|
||||
filters_config = config.filters_config
|
||||
|
||||
# Build response safely without exposing secrets
|
||||
discord_config = notifications_config.get("discord", {})
|
||||
telegram_config = notifications_config.get("telegram", {})
|
||||
|
||||
settings = NotificationSettings(
|
||||
discord=DiscordConfig(
|
||||
webhook="***" if discord_config.get("webhook") else "",
|
||||
enabled=discord_config.get("enabled", True),
|
||||
)
|
||||
if discord_config.get("webhook")
|
||||
else None,
|
||||
telegram=TelegramConfig(
|
||||
token="***" if telegram_config.get("api-key") else "",
|
||||
chat_id=telegram_config.get("chat-id", 0),
|
||||
enabled=telegram_config.get("enabled", True),
|
||||
)
|
||||
if telegram_config.get("api-key")
|
||||
else None,
|
||||
filters=NotificationFilters(
|
||||
case_insensitive=filters_config.get("case-insensitive", []),
|
||||
case_sensitive=filters_config.get("case-sensitive"),
|
||||
),
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=settings,
|
||||
message="Notification settings retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get notification settings: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get notification settings: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/notifications/settings", response_model=APIResponse)
|
||||
async def update_notification_settings(settings: NotificationSettings) -> APIResponse:
|
||||
"""Update notification settings"""
|
||||
try:
|
||||
# Update notifications config
|
||||
notifications_config = {}
|
||||
|
||||
if settings.discord:
|
||||
notifications_config["discord"] = {
|
||||
"webhook": settings.discord.webhook,
|
||||
"enabled": settings.discord.enabled,
|
||||
}
|
||||
|
||||
if settings.telegram:
|
||||
notifications_config["telegram"] = {
|
||||
"api-key": settings.telegram.token,
|
||||
"chat-id": settings.telegram.chat_id,
|
||||
"enabled": settings.telegram.enabled,
|
||||
}
|
||||
|
||||
# Update filters config
|
||||
filters_config: Dict[str, Any] = {}
|
||||
if settings.filters.case_insensitive:
|
||||
filters_config["case-insensitive"] = settings.filters.case_insensitive
|
||||
if settings.filters.case_sensitive:
|
||||
filters_config["case-sensitive"] = settings.filters.case_sensitive
|
||||
|
||||
# Save to config
|
||||
if notifications_config:
|
||||
config.update_section("notifications", notifications_config)
|
||||
if filters_config:
|
||||
config.update_section("filters", filters_config)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"updated": True},
|
||||
message="Notification settings updated successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update notification settings: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update notification settings: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/notifications/test", response_model=APIResponse)
|
||||
async def test_notification(test_request: NotificationTest) -> APIResponse:
|
||||
"""Send a test notification"""
|
||||
try:
|
||||
success = await notification_service.send_test_notification(
|
||||
test_request.service, test_request.message
|
||||
)
|
||||
|
||||
if success:
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"sent": True},
|
||||
message=f"Test notification sent to {test_request.service} successfully",
|
||||
)
|
||||
else:
|
||||
return APIResponse(
|
||||
success=False,
|
||||
message=f"Failed to send test notification to {test_request.service}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send test notification: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to send test notification: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/notifications/services", response_model=APIResponse)
|
||||
async def get_notification_services() -> APIResponse:
|
||||
"""Get available notification services and their status"""
|
||||
try:
|
||||
notifications_config = config.notifications_config
|
||||
|
||||
services = {
|
||||
"discord": {
|
||||
"name": "Discord",
|
||||
"enabled": bool(notifications_config.get("discord", {}).get("webhook")),
|
||||
"configured": bool(
|
||||
notifications_config.get("discord", {}).get("webhook")
|
||||
),
|
||||
"active": notifications_config.get("discord", {}).get("enabled", True),
|
||||
},
|
||||
"telegram": {
|
||||
"name": "Telegram",
|
||||
"enabled": bool(
|
||||
notifications_config.get("telegram", {}).get("api-key")
|
||||
and notifications_config.get("telegram", {}).get("chat-id")
|
||||
),
|
||||
"configured": bool(
|
||||
notifications_config.get("telegram", {}).get("api-key")
|
||||
and notifications_config.get("telegram", {}).get("chat-id")
|
||||
),
|
||||
"active": notifications_config.get("telegram", {}).get("enabled", True),
|
||||
},
|
||||
}
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=services,
|
||||
message="Notification services status retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get notification services: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get notification services: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/notifications/settings/{service}", response_model=APIResponse)
|
||||
async def delete_notification_service(service: str) -> APIResponse:
|
||||
"""Delete/disable a notification service"""
|
||||
try:
|
||||
if service not in ["discord", "telegram"]:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Service must be 'discord' or 'telegram'"
|
||||
)
|
||||
|
||||
notifications_config = config.notifications_config.copy()
|
||||
if service in notifications_config:
|
||||
del notifications_config[service]
|
||||
config.update_section("notifications", notifications_config)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={"deleted": service},
|
||||
message=f"{service.capitalize()} notification service deleted successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete notification service {service}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete notification service: {str(e)}"
|
||||
) from e
|
||||
212
leggen/api/routes/sync.py
Normal file
212
leggen/api/routes/sync.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.models.common import APIResponse
|
||||
from leggen.api.models.sync import SyncRequest, SchedulerConfig
|
||||
from leggen.services.sync_service import SyncService
|
||||
from leggen.background.scheduler import scheduler
|
||||
from leggen.utils.config import config
|
||||
|
||||
router = APIRouter()
|
||||
sync_service = SyncService()
|
||||
|
||||
|
||||
@router.get("/sync/status", response_model=APIResponse)
|
||||
async def get_sync_status() -> APIResponse:
|
||||
"""Get current sync status"""
|
||||
try:
|
||||
status = await sync_service.get_sync_status()
|
||||
|
||||
# Add scheduler information
|
||||
next_sync_time = scheduler.get_next_sync_time()
|
||||
if next_sync_time:
|
||||
status.next_sync = next_sync_time
|
||||
|
||||
return APIResponse(
|
||||
success=True, data=status, message="Sync status retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get sync status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get sync status: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/sync", response_model=APIResponse)
|
||||
async def trigger_sync(
|
||||
background_tasks: BackgroundTasks, sync_request: Optional[SyncRequest] = None
|
||||
) -> APIResponse:
|
||||
"""Trigger a manual sync operation"""
|
||||
try:
|
||||
# Check if sync is already running
|
||||
status = await sync_service.get_sync_status()
|
||||
if status.is_running and not (sync_request and sync_request.force):
|
||||
return APIResponse(
|
||||
success=False,
|
||||
message="Sync is already running. Use 'force: true' to override.",
|
||||
)
|
||||
|
||||
# Determine what to sync
|
||||
if sync_request and sync_request.account_ids:
|
||||
# Sync specific accounts in background
|
||||
background_tasks.add_task(
|
||||
sync_service.sync_specific_accounts,
|
||||
sync_request.account_ids,
|
||||
sync_request.force if sync_request else False,
|
||||
)
|
||||
message = (
|
||||
f"Started sync for {len(sync_request.account_ids)} specific accounts"
|
||||
)
|
||||
else:
|
||||
# Sync all accounts in background
|
||||
background_tasks.add_task(
|
||||
sync_service.sync_all_accounts,
|
||||
sync_request.force if sync_request else False,
|
||||
)
|
||||
message = "Started sync for all accounts"
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={
|
||||
"sync_started": True,
|
||||
"force": sync_request.force if sync_request else False,
|
||||
},
|
||||
message=message,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to trigger sync: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to trigger sync: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/sync/now", response_model=APIResponse)
|
||||
async def sync_now(sync_request: Optional[SyncRequest] = None) -> APIResponse:
|
||||
"""Run sync synchronously and return results (slower, for testing)"""
|
||||
try:
|
||||
if sync_request and sync_request.account_ids:
|
||||
result = await sync_service.sync_specific_accounts(
|
||||
sync_request.account_ids, sync_request.force
|
||||
)
|
||||
else:
|
||||
result = await sync_service.sync_all_accounts(
|
||||
sync_request.force if sync_request else False
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
success=result.success,
|
||||
data=result,
|
||||
message="Sync completed"
|
||||
if result.success
|
||||
else f"Sync failed with {len(result.errors)} errors",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run sync: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to run sync: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/sync/scheduler", response_model=APIResponse)
|
||||
async def get_scheduler_config() -> APIResponse:
|
||||
"""Get current scheduler configuration"""
|
||||
try:
|
||||
scheduler_config = config.scheduler_config
|
||||
next_sync_time = scheduler.get_next_sync_time()
|
||||
|
||||
response_data = {
|
||||
**scheduler_config,
|
||||
"next_scheduled_sync": next_sync_time.isoformat()
|
||||
if next_sync_time
|
||||
else None,
|
||||
"is_running": scheduler.scheduler.running
|
||||
if hasattr(scheduler, "scheduler")
|
||||
else False,
|
||||
}
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=response_data,
|
||||
message="Scheduler configuration retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get scheduler config: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get scheduler config: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/sync/scheduler", response_model=APIResponse)
|
||||
async def update_scheduler_config(scheduler_config: SchedulerConfig) -> APIResponse:
|
||||
"""Update scheduler configuration"""
|
||||
try:
|
||||
# Validate cron expression if provided
|
||||
if scheduler_config.cron:
|
||||
try:
|
||||
cron_parts = scheduler_config.cron.split()
|
||||
if len(cron_parts) != 5:
|
||||
raise ValueError(
|
||||
"Cron expression must have 5 parts: minute hour day month day_of_week"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid cron expression: {str(e)}"
|
||||
) from e
|
||||
|
||||
# Update configuration
|
||||
schedule_data = scheduler_config.dict(exclude_none=True)
|
||||
config.update_section("scheduler", {"sync": schedule_data})
|
||||
|
||||
# Reschedule the job
|
||||
scheduler.reschedule_sync(schedule_data)
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=schedule_data,
|
||||
message="Scheduler configuration updated successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update scheduler config: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update scheduler config: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/sync/scheduler/start", response_model=APIResponse)
|
||||
async def start_scheduler() -> APIResponse:
|
||||
"""Start the background scheduler"""
|
||||
try:
|
||||
if not scheduler.scheduler.running:
|
||||
scheduler.start()
|
||||
return APIResponse(success=True, message="Scheduler started successfully")
|
||||
else:
|
||||
return APIResponse(success=True, message="Scheduler is already running")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start scheduler: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to start scheduler: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/sync/scheduler/stop", response_model=APIResponse)
|
||||
async def stop_scheduler() -> APIResponse:
|
||||
"""Stop the background scheduler"""
|
||||
try:
|
||||
if scheduler.scheduler.running:
|
||||
scheduler.shutdown()
|
||||
return APIResponse(success=True, message="Scheduler stopped successfully")
|
||||
else:
|
||||
return APIResponse(success=True, message="Scheduler is already stopped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop scheduler: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to stop scheduler: {str(e)}"
|
||||
) from e
|
||||
254
leggen/api/routes/transactions.py
Normal file
254
leggen/api/routes/transactions.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from typing import Optional, List, Union
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.models.common import APIResponse, PaginatedResponse
|
||||
from leggen.api.models.accounts import Transaction, TransactionSummary
|
||||
from leggen.services.database_service import DatabaseService
|
||||
|
||||
router = APIRouter()
|
||||
database_service = DatabaseService()
|
||||
|
||||
|
||||
@router.get("/transactions", response_model=PaginatedResponse)
|
||||
async def get_all_transactions(
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-based)"),
|
||||
per_page: int = Query(default=50, le=500, description="Items per page"),
|
||||
summary_only: bool = Query(
|
||||
default=True, description="Return transaction summaries only"
|
||||
),
|
||||
date_from: Optional[str] = Query(
|
||||
default=None, description="Filter from date (YYYY-MM-DD)"
|
||||
),
|
||||
date_to: Optional[str] = Query(
|
||||
default=None, description="Filter to date (YYYY-MM-DD)"
|
||||
),
|
||||
min_amount: Optional[float] = Query(
|
||||
default=None, description="Minimum transaction amount"
|
||||
),
|
||||
max_amount: Optional[float] = Query(
|
||||
default=None, description="Maximum transaction amount"
|
||||
),
|
||||
search: Optional[str] = Query(
|
||||
default=None, description="Search in transaction descriptions"
|
||||
),
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||
) -> PaginatedResponse:
|
||||
"""Get all transactions from database with filtering options"""
|
||||
try:
|
||||
# Calculate offset from page and per_page
|
||||
offset = (page - 1) * per_page
|
||||
limit = per_page
|
||||
|
||||
# Get transactions from database instead of GoCardless API
|
||||
db_transactions = await database_service.get_transactions_from_db(
|
||||
account_id=account_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
min_amount=min_amount,
|
||||
max_amount=max_amount,
|
||||
search=search,
|
||||
)
|
||||
|
||||
# Get total count for pagination info (respecting the same filters)
|
||||
total_transactions = await database_service.get_transaction_count_from_db(
|
||||
account_id=account_id,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
min_amount=min_amount,
|
||||
max_amount=max_amount,
|
||||
search=search,
|
||||
)
|
||||
|
||||
data: Union[List[TransactionSummary], List[Transaction]]
|
||||
|
||||
if summary_only:
|
||||
# Return simplified transaction summaries
|
||||
data = [
|
||||
TransactionSummary(
|
||||
transaction_id=txn["transactionId"], # NEW: stable bank-provided ID
|
||||
internal_transaction_id=txn.get("internalTransactionId"),
|
||||
date=txn["transactionDate"],
|
||||
description=txn["description"],
|
||||
amount=txn["transactionValue"],
|
||||
currency=txn["transactionCurrency"],
|
||||
status=txn["transactionStatus"],
|
||||
account_id=txn["accountId"],
|
||||
)
|
||||
for txn in db_transactions
|
||||
]
|
||||
else:
|
||||
# Return full transaction details
|
||||
data = [
|
||||
Transaction(
|
||||
transaction_id=txn["transactionId"], # NEW: stable bank-provided ID
|
||||
internal_transaction_id=txn.get("internalTransactionId"),
|
||||
institution_id=txn["institutionId"],
|
||||
iban=txn["iban"],
|
||||
account_id=txn["accountId"],
|
||||
transaction_date=txn["transactionDate"],
|
||||
description=txn["description"],
|
||||
transaction_value=txn["transactionValue"],
|
||||
transaction_currency=txn["transactionCurrency"],
|
||||
transaction_status=txn["transactionStatus"],
|
||||
raw_transaction=txn["rawTransaction"],
|
||||
)
|
||||
for txn in db_transactions
|
||||
]
|
||||
|
||||
total_pages = (total_transactions + per_page - 1) // per_page
|
||||
|
||||
return PaginatedResponse(
|
||||
success=True,
|
||||
data=data,
|
||||
pagination={
|
||||
"total": total_transactions,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"total_pages": total_pages,
|
||||
"has_next": page < total_pages,
|
||||
"has_prev": page > 1,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transactions from database: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get transactions: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/transactions/stats", response_model=APIResponse)
|
||||
async def get_transaction_stats(
|
||||
days: int = Query(default=30, description="Number of days to include in stats"),
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||
) -> APIResponse:
|
||||
"""Get transaction statistics for the last N days from database"""
|
||||
try:
|
||||
# Date range for stats
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
# Format dates for database query
|
||||
date_from = start_date.isoformat()
|
||||
date_to = end_date.isoformat()
|
||||
|
||||
# Get transactions from database
|
||||
recent_transactions = await database_service.get_transactions_from_db(
|
||||
account_id=account_id,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
limit=None, # Get all matching transactions for stats
|
||||
)
|
||||
|
||||
# Calculate stats
|
||||
total_transactions = len(recent_transactions)
|
||||
total_income = sum(
|
||||
txn["transactionValue"]
|
||||
for txn in recent_transactions
|
||||
if txn["transactionValue"] > 0
|
||||
)
|
||||
total_expenses = sum(
|
||||
abs(txn["transactionValue"])
|
||||
for txn in recent_transactions
|
||||
if txn["transactionValue"] < 0
|
||||
)
|
||||
net_change = total_income - total_expenses
|
||||
|
||||
# Count by status
|
||||
booked_count = len(
|
||||
[txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]
|
||||
)
|
||||
pending_count = len(
|
||||
[
|
||||
txn
|
||||
for txn in recent_transactions
|
||||
if txn["transactionStatus"] == "pending"
|
||||
]
|
||||
)
|
||||
|
||||
# Count unique accounts
|
||||
unique_accounts = len({txn["accountId"] for txn in recent_transactions})
|
||||
|
||||
stats = {
|
||||
"period_days": days,
|
||||
"total_transactions": total_transactions,
|
||||
"booked_transactions": booked_count,
|
||||
"pending_transactions": pending_count,
|
||||
"total_income": round(total_income, 2),
|
||||
"total_expenses": round(total_expenses, 2),
|
||||
"net_change": round(net_change, 2),
|
||||
"average_transaction": round(
|
||||
sum(txn["transactionValue"] for txn in recent_transactions)
|
||||
/ total_transactions,
|
||||
2,
|
||||
)
|
||||
if total_transactions > 0
|
||||
else 0,
|
||||
"accounts_included": unique_accounts,
|
||||
}
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=stats,
|
||||
message=f"Transaction statistics for last {days} days",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transaction stats from database: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get transaction stats: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/transactions/analytics", response_model=APIResponse)
|
||||
async def get_transactions_for_analytics(
|
||||
days: int = Query(default=365, description="Number of days to include"),
|
||||
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
|
||||
) -> APIResponse:
|
||||
"""Get all transactions for analytics (no pagination) for the last N days"""
|
||||
try:
|
||||
# Date range for analytics
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
# Format dates for database query
|
||||
date_from = start_date.isoformat()
|
||||
date_to = end_date.isoformat()
|
||||
|
||||
# Get ALL transactions from database (no limit for analytics)
|
||||
transactions = await database_service.get_transactions_from_db(
|
||||
account_id=account_id,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
limit=None, # No limit - get all transactions
|
||||
)
|
||||
|
||||
# Transform for frontend (summary format)
|
||||
transaction_summaries = [
|
||||
{
|
||||
"transaction_id": txn["transactionId"],
|
||||
"date": txn["transactionDate"],
|
||||
"description": txn["description"],
|
||||
"amount": txn["transactionValue"],
|
||||
"currency": txn["transactionCurrency"],
|
||||
"status": txn["transactionStatus"],
|
||||
"account_id": txn["accountId"],
|
||||
}
|
||||
for txn in transactions
|
||||
]
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data=transaction_summaries,
|
||||
message=f"Retrieved {len(transaction_summaries)} transactions for analytics",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transactions for analytics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get analytics transactions: {str(e)}"
|
||||
) from e
|
||||
@@ -6,15 +6,15 @@ from urllib.parse import urljoin
|
||||
from leggen.utils.text import error
|
||||
|
||||
|
||||
class LeggendAPIClient:
|
||||
"""Client for communicating with the leggend FastAPI service"""
|
||||
class LeggenAPIClient:
|
||||
"""Client for communicating with the leggen FastAPI service"""
|
||||
|
||||
base_url: str
|
||||
|
||||
def __init__(self, base_url: Optional[str] = None):
|
||||
self.base_url = (
|
||||
base_url
|
||||
or os.environ.get("LEGGEND_API_URL", "http://localhost:8000")
|
||||
or os.environ.get("LEGGEN_API_URL", "http://localhost:8000")
|
||||
or "http://localhost:8000"
|
||||
)
|
||||
self.session = requests.Session()
|
||||
@@ -31,7 +31,7 @@ class LeggendAPIClient:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
error("Could not connect to leggend service. Is it running?")
|
||||
error("Could not connect to leggen server. Is it running?")
|
||||
error(f"Trying to connect to: {self.base_url}")
|
||||
raise
|
||||
except requests.exceptions.HTTPError as e:
|
||||
@@ -48,7 +48,7 @@ class LeggendAPIClient:
|
||||
raise
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""Check if the leggend service is healthy"""
|
||||
"""Check if the leggen server is healthy"""
|
||||
try:
|
||||
response = self._make_request("GET", "/health")
|
||||
return response.get("status") == "healthy"
|
||||
|
||||
168
leggen/background/scheduler.py
Normal file
168
leggen/background/scheduler.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from loguru import logger
|
||||
|
||||
from leggen.utils.config import config
|
||||
from leggen.services.sync_service import SyncService
|
||||
from leggen.services.notification_service import NotificationService
|
||||
|
||||
|
||||
class BackgroundScheduler:
|
||||
def __init__(self):
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.sync_service = SyncService()
|
||||
self.notification_service = NotificationService()
|
||||
self.max_retries = 3
|
||||
self.retry_delay = 300 # 5 minutes
|
||||
|
||||
def start(self):
|
||||
"""Start the scheduler and configure sync jobs based on configuration"""
|
||||
schedule_config = config.scheduler_config.get("sync", {})
|
||||
|
||||
if not schedule_config.get("enabled", True):
|
||||
logger.info("Sync scheduling is disabled in configuration")
|
||||
self.scheduler.start()
|
||||
return
|
||||
|
||||
# Parse schedule configuration
|
||||
trigger = self._parse_cron_config(schedule_config)
|
||||
if not trigger:
|
||||
return
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._run_sync,
|
||||
trigger,
|
||||
id="daily_sync",
|
||||
name="Scheduled sync of all transactions",
|
||||
max_instances=1,
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
logger.info(f"Background scheduler started with sync job: {trigger}")
|
||||
|
||||
def shutdown(self):
|
||||
if self.scheduler.running:
|
||||
self.scheduler.shutdown()
|
||||
logger.info("Background scheduler shutdown")
|
||||
|
||||
def reschedule_sync(self, schedule_config: dict):
|
||||
"""Reschedule the sync job with new configuration"""
|
||||
if self.scheduler.running:
|
||||
try:
|
||||
self.scheduler.remove_job("daily_sync")
|
||||
logger.info("Removed existing sync job")
|
||||
except Exception:
|
||||
pass # Job might not exist
|
||||
|
||||
if not schedule_config.get("enabled", True):
|
||||
logger.info("Sync scheduling disabled")
|
||||
return
|
||||
|
||||
# Configure new schedule
|
||||
trigger = self._parse_cron_config(schedule_config)
|
||||
if not trigger:
|
||||
return
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._run_sync,
|
||||
trigger,
|
||||
id="daily_sync",
|
||||
name="Scheduled sync of all transactions",
|
||||
max_instances=1,
|
||||
)
|
||||
logger.info(f"Rescheduled sync job with: {trigger}")
|
||||
|
||||
def _parse_cron_config(self, schedule_config: dict) -> CronTrigger:
|
||||
"""Parse cron configuration and return CronTrigger"""
|
||||
if schedule_config.get("cron"):
|
||||
# Parse custom cron expression (e.g., "0 3 * * *" for daily at 3 AM)
|
||||
try:
|
||||
cron_parts = schedule_config["cron"].split()
|
||||
if len(cron_parts) == 5:
|
||||
minute, hour, day, month, day_of_week = cron_parts
|
||||
return CronTrigger(
|
||||
minute=minute,
|
||||
hour=hour,
|
||||
day=day if day != "*" else None,
|
||||
month=month if month != "*" else None,
|
||||
day_of_week=day_of_week if day_of_week != "*" else None,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Invalid cron expression: {schedule_config['cron']}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing cron expression: {e}")
|
||||
return None
|
||||
else:
|
||||
# Use hour/minute configuration (default: 3:00 AM daily)
|
||||
hour = schedule_config.get("hour", 3)
|
||||
minute = schedule_config.get("minute", 0)
|
||||
return CronTrigger(hour=hour, minute=minute)
|
||||
|
||||
async def _run_sync(self, retry_count: int = 0):
|
||||
"""Run sync with enhanced error handling and retry logic"""
|
||||
try:
|
||||
logger.info("Starting scheduled sync job")
|
||||
await self.sync_service.sync_all_accounts()
|
||||
logger.info("Scheduled sync job completed successfully")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Scheduled sync job failed (attempt {retry_count + 1}/{self.max_retries}): {e}"
|
||||
)
|
||||
|
||||
# Send notification about the failure
|
||||
try:
|
||||
await self.notification_service.send_expiry_notification(
|
||||
{
|
||||
"type": "sync_failure",
|
||||
"error": str(e),
|
||||
"retry_count": retry_count + 1,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
)
|
||||
except Exception as notification_error:
|
||||
logger.error(
|
||||
f"Failed to send failure notification: {notification_error}"
|
||||
)
|
||||
|
||||
# Implement retry logic for transient failures
|
||||
if retry_count < self.max_retries - 1:
|
||||
import datetime
|
||||
|
||||
logger.info(f"Retrying sync job in {self.retry_delay} seconds...")
|
||||
# Schedule a retry
|
||||
retry_time = datetime.datetime.now() + datetime.timedelta(
|
||||
seconds=self.retry_delay
|
||||
)
|
||||
self.scheduler.add_job(
|
||||
self._run_sync,
|
||||
"date",
|
||||
args=[retry_count + 1],
|
||||
id=f"sync_retry_{retry_count + 1}",
|
||||
run_date=retry_time,
|
||||
)
|
||||
else:
|
||||
logger.error("Maximum retries exceeded for sync job")
|
||||
# Send final failure notification
|
||||
try:
|
||||
await self.notification_service.send_expiry_notification(
|
||||
{
|
||||
"type": "sync_final_failure",
|
||||
"error": str(e),
|
||||
"retry_count": retry_count + 1,
|
||||
}
|
||||
)
|
||||
except Exception as notification_error:
|
||||
logger.error(
|
||||
f"Failed to send final failure notification: {notification_error}"
|
||||
)
|
||||
|
||||
def get_next_sync_time(self):
|
||||
"""Get the next scheduled sync time"""
|
||||
job = self.scheduler.get_job("daily_sync")
|
||||
if job:
|
||||
return job.next_run_time
|
||||
return None
|
||||
|
||||
|
||||
scheduler = BackgroundScheduler()
|
||||
@@ -1,7 +1,7 @@
|
||||
import click
|
||||
|
||||
from leggen.main import cli
|
||||
from leggen.api_client import LeggendAPIClient
|
||||
from leggen.api_client import LeggenAPIClient
|
||||
from leggen.utils.text import datefmt, print_table
|
||||
|
||||
|
||||
@@ -11,12 +11,12 @@ def balances(ctx: click.Context):
|
||||
"""
|
||||
List balances of all connected accounts
|
||||
"""
|
||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||
api_client = LeggenAPIClient(ctx.obj.get("api_url"))
|
||||
|
||||
# Check if leggend service is available
|
||||
# Check if leggen server is available
|
||||
if not api_client.health_check():
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
"Error: Cannot connect to leggen server. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import click
|
||||
|
||||
from leggen.main import cli
|
||||
from leggen.api_client import LeggendAPIClient
|
||||
from leggen.api_client import LeggenAPIClient
|
||||
from leggen.utils.disk import save_file
|
||||
from leggen.utils.text import info, print_table, warning, success
|
||||
|
||||
@@ -12,12 +12,12 @@ def add(ctx):
|
||||
"""
|
||||
Connect to a bank
|
||||
"""
|
||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||
api_client = LeggenAPIClient(ctx.obj.get("api_url"))
|
||||
|
||||
# Check if leggend service is available
|
||||
# Check if leggen server is available
|
||||
if not api_client.health_check():
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
"Error: Cannot connect to leggen server. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import click
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--database",
|
||||
|
||||
178
leggen/commands/server.py
Normal file
178
leggen/commands/server.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib import metadata
|
||||
|
||||
import click
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.routes import banks, accounts, sync, notifications, transactions
|
||||
from leggen.background.scheduler import scheduler
|
||||
from leggen.utils.config import config
|
||||
from leggen.utils.paths import path_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
logger.info("Starting leggen server...")
|
||||
|
||||
# Load configuration
|
||||
try:
|
||||
config.load_config()
|
||||
logger.info("Configuration loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load configuration: {e}")
|
||||
raise
|
||||
|
||||
# Run database migrations
|
||||
try:
|
||||
from leggen.services.database_service import DatabaseService
|
||||
|
||||
db_service = DatabaseService()
|
||||
await db_service.run_migrations_if_needed()
|
||||
logger.info("Database migrations completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Database migration failed: {e}")
|
||||
raise
|
||||
|
||||
# Start background scheduler
|
||||
scheduler.start()
|
||||
logger.info("Background scheduler started")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down leggen server...")
|
||||
scheduler.shutdown()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
# Get version dynamically from package metadata
|
||||
try:
|
||||
version = metadata.version("leggen")
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "unknown"
|
||||
|
||||
app = FastAPI(
|
||||
title="Leggen API",
|
||||
description="Open Banking API for Leggen",
|
||||
version=version,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
"http://frontend:80",
|
||||
], # Frontend container and dev servers
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
app.include_router(banks.router, prefix="/api/v1", tags=["banks"])
|
||||
app.include_router(accounts.router, prefix="/api/v1", tags=["accounts"])
|
||||
app.include_router(transactions.router, prefix="/api/v1", tags=["transactions"])
|
||||
app.include_router(sync.router, prefix="/api/v1", tags=["sync"])
|
||||
app.include_router(notifications.router, prefix="/api/v1", tags=["notifications"])
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
# Get version dynamically
|
||||
try:
|
||||
version = metadata.version("leggen")
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "unknown"
|
||||
return {"message": "Leggen API is running", "version": version}
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def health():
|
||||
"""Health check endpoint for API connectivity"""
|
||||
try:
|
||||
from leggen.api.models.common import APIResponse
|
||||
|
||||
config_loaded = config._config is not None
|
||||
|
||||
return APIResponse(
|
||||
success=True,
|
||||
data={
|
||||
"status": "healthy",
|
||||
"config_loaded": config_loaded,
|
||||
"message": "API is running and responsive",
|
||||
},
|
||||
message="Health check successful",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}")
|
||||
from leggen.api.models.common import APIResponse
|
||||
|
||||
return APIResponse(
|
||||
success=False,
|
||||
data={"status": "unhealthy", "error": str(e)},
|
||||
message="Health check failed",
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--reload",
|
||||
is_flag=True,
|
||||
help="Enable auto-reload for development",
|
||||
)
|
||||
@click.option(
|
||||
"--host",
|
||||
default="0.0.0.0",
|
||||
help="Host to bind to (default: 0.0.0.0)",
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to bind to (default: 8000)",
|
||||
)
|
||||
@click.pass_context
|
||||
def server(ctx: click.Context, reload: bool, host: str, port: int):
|
||||
"""Start the Leggen API server"""
|
||||
|
||||
# Get config_dir and database from main CLI context
|
||||
config_dir = None
|
||||
database = None
|
||||
if ctx.parent:
|
||||
config_dir = ctx.parent.params.get("config_dir")
|
||||
database = ctx.parent.params.get("database")
|
||||
|
||||
# Set up path manager with user-provided paths
|
||||
if config_dir:
|
||||
path_manager.set_config_dir(config_dir)
|
||||
if database:
|
||||
path_manager.set_database_path(database)
|
||||
|
||||
if reload:
|
||||
# Use string import for reload to work properly
|
||||
uvicorn.run(
|
||||
"leggen.commands.server:create_app",
|
||||
factory=True,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
access_log=True,
|
||||
reload=True,
|
||||
reload_dirs=["leggen"], # Watch leggen directory
|
||||
)
|
||||
else:
|
||||
app = create_app()
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
access_log=True,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import click
|
||||
|
||||
from leggen.main import cli
|
||||
from leggen.api_client import LeggendAPIClient
|
||||
from leggen.api_client import LeggenAPIClient
|
||||
from leggen.utils.text import datefmt, echo, info, print_table
|
||||
|
||||
|
||||
@@ -11,12 +11,12 @@ def status(ctx: click.Context):
|
||||
"""
|
||||
List all connected banks and their status
|
||||
"""
|
||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||
api_client = LeggenAPIClient(ctx.obj.get("api_url"))
|
||||
|
||||
# Check if leggend service is available
|
||||
# Check if leggen server is available
|
||||
if not api_client.health_check():
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
"Error: Cannot connect to leggen server. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import click
|
||||
|
||||
from leggen.main import cli
|
||||
from leggen.api_client import LeggendAPIClient
|
||||
from leggen.api_client import LeggenAPIClient
|
||||
from leggen.utils.text import error, info, success
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@ def sync(ctx: click.Context, wait: bool, force: bool):
|
||||
"""
|
||||
Sync all transactions with database
|
||||
"""
|
||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||
api_client = LeggenAPIClient(ctx.obj.get("api_url"))
|
||||
|
||||
# Check if leggend service is available
|
||||
# Check if leggen server is available
|
||||
if not api_client.health_check():
|
||||
error("Cannot connect to leggend service. Please ensure it's running.")
|
||||
error("Cannot connect to leggen server. Please ensure it's running.")
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import click
|
||||
|
||||
from leggen.main import cli
|
||||
from leggen.api_client import LeggendAPIClient
|
||||
from leggen.api_client import LeggenAPIClient
|
||||
from leggen.utils.text import datefmt, info, print_table
|
||||
|
||||
|
||||
@@ -20,12 +20,12 @@ def transactions(ctx: click.Context, account: str, limit: int, full: bool):
|
||||
|
||||
If the --account option is used, it will only list transactions for that account.
|
||||
"""
|
||||
api_client = LeggendAPIClient(ctx.obj.get("api_url"))
|
||||
api_client = LeggenAPIClient(ctx.obj.get("api_url"))
|
||||
|
||||
# Check if leggend service is available
|
||||
# Check if leggen server is available
|
||||
if not api_client.health_check():
|
||||
click.echo(
|
||||
"Error: Cannot connect to leggend service. Please ensure it's running."
|
||||
"Error: Cannot connect to leggen server. Please ensure it's running."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -527,11 +527,11 @@ def get_historical_balances(account_id=None, days=365):
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
return []
|
||||
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
try:
|
||||
# Get current balance for each account/type to use as the final balance
|
||||
current_balances_query = """
|
||||
@@ -544,107 +544,115 @@ def get_historical_balances(account_id=None, days=365):
|
||||
)
|
||||
"""
|
||||
params = []
|
||||
|
||||
|
||||
if account_id:
|
||||
current_balances_query += " AND b1.account_id = ?"
|
||||
params.append(account_id)
|
||||
|
||||
|
||||
cursor.execute(current_balances_query, params)
|
||||
current_balances = {
|
||||
(row['account_id'], row['type']): {
|
||||
'amount': row['amount'],
|
||||
'currency': row['currency']
|
||||
(row["account_id"], row["type"]): {
|
||||
"amount": row["amount"],
|
||||
"currency": row["currency"],
|
||||
}
|
||||
for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
|
||||
# Get transactions for the specified period, ordered by date descending
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
cutoff_date = (datetime.now() - timedelta(days=days)).isoformat()
|
||||
|
||||
|
||||
transactions_query = """
|
||||
SELECT accountId, transactionDate, transactionValue
|
||||
FROM transactions
|
||||
WHERE transactionDate >= ?
|
||||
"""
|
||||
|
||||
|
||||
if account_id:
|
||||
transactions_query += " AND accountId = ?"
|
||||
params = [cutoff_date, account_id]
|
||||
else:
|
||||
params = [cutoff_date]
|
||||
|
||||
|
||||
transactions_query += " ORDER BY transactionDate DESC"
|
||||
|
||||
|
||||
cursor.execute(transactions_query, params)
|
||||
transactions = cursor.fetchall()
|
||||
|
||||
|
||||
# Calculate historical balances by working backwards from current balance
|
||||
historical_balances = []
|
||||
account_running_balances: dict[str, dict[str, float]] = {}
|
||||
|
||||
|
||||
# Initialize running balances with current balances
|
||||
for (acc_id, balance_type), balance_info in current_balances.items():
|
||||
if acc_id not in account_running_balances:
|
||||
account_running_balances[acc_id] = {}
|
||||
account_running_balances[acc_id][balance_type] = balance_info['amount']
|
||||
|
||||
account_running_balances[acc_id][balance_type] = balance_info["amount"]
|
||||
|
||||
# Group transactions by date
|
||||
from collections import defaultdict
|
||||
|
||||
transactions_by_date = defaultdict(list)
|
||||
|
||||
|
||||
for txn in transactions:
|
||||
date_str = txn['transactionDate'][:10] # Extract just the date part
|
||||
date_str = txn["transactionDate"][:10] # Extract just the date part
|
||||
transactions_by_date[date_str].append(txn)
|
||||
|
||||
|
||||
# Generate historical balance points
|
||||
# Start from today and work backwards
|
||||
current_date = datetime.now().date()
|
||||
|
||||
|
||||
for day_offset in range(0, days, 7): # Sample every 7 days for performance
|
||||
target_date = current_date - timedelta(days=day_offset)
|
||||
target_date_str = target_date.isoformat()
|
||||
|
||||
|
||||
# For each account, create balance entries
|
||||
for acc_id in account_running_balances:
|
||||
for balance_type in ['closingBooked']: # Focus on closingBooked for the chart
|
||||
for balance_type in [
|
||||
"closingBooked"
|
||||
]: # Focus on closingBooked for the chart
|
||||
if balance_type in account_running_balances[acc_id]:
|
||||
balance_amount = account_running_balances[acc_id][balance_type]
|
||||
currency = current_balances.get((acc_id, balance_type), {}).get('currency', 'EUR')
|
||||
|
||||
historical_balances.append({
|
||||
'id': f"{acc_id}_{balance_type}_{target_date_str}",
|
||||
'account_id': acc_id,
|
||||
'balance_amount': balance_amount,
|
||||
'balance_type': balance_type,
|
||||
'currency': currency,
|
||||
'reference_date': target_date_str,
|
||||
'created_at': None,
|
||||
'updated_at': None
|
||||
})
|
||||
|
||||
currency = current_balances.get((acc_id, balance_type), {}).get(
|
||||
"currency", "EUR"
|
||||
)
|
||||
|
||||
historical_balances.append(
|
||||
{
|
||||
"id": f"{acc_id}_{balance_type}_{target_date_str}",
|
||||
"account_id": acc_id,
|
||||
"balance_amount": balance_amount,
|
||||
"balance_type": balance_type,
|
||||
"currency": currency,
|
||||
"reference_date": target_date_str,
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
)
|
||||
|
||||
# Subtract transactions that occurred on this date and later dates
|
||||
# to simulate going back in time
|
||||
for date_str in list(transactions_by_date.keys()):
|
||||
if date_str >= target_date_str:
|
||||
for txn in transactions_by_date[date_str]:
|
||||
acc_id = txn['accountId']
|
||||
amount = txn['transactionValue']
|
||||
|
||||
acc_id = txn["accountId"]
|
||||
amount = txn["transactionValue"]
|
||||
|
||||
if acc_id in account_running_balances:
|
||||
for balance_type in account_running_balances[acc_id]:
|
||||
account_running_balances[acc_id][balance_type] -= amount
|
||||
|
||||
|
||||
# Remove processed transactions to avoid double-processing
|
||||
del transactions_by_date[date_str]
|
||||
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
# Sort by date for proper chronological order
|
||||
historical_balances.sort(key=lambda x: x['reference_date'])
|
||||
|
||||
historical_balances.sort(key=lambda x: x["reference_date"])
|
||||
|
||||
return historical_balances
|
||||
|
||||
|
||||
except Exception as e:
|
||||
conn.close()
|
||||
raise e
|
||||
|
||||
@@ -105,9 +105,9 @@ class Group(click.Group):
|
||||
"--api-url",
|
||||
type=str,
|
||||
default="http://localhost:8000",
|
||||
envvar="LEGGEND_API_URL",
|
||||
envvar="LEGGEN_API_URL",
|
||||
show_envvar=True,
|
||||
help="URL of the leggend API service",
|
||||
help="URL of the leggen API service",
|
||||
)
|
||||
@click.group(
|
||||
cls=Group,
|
||||
|
||||
903
leggen/services/database_service.py
Normal file
903
leggen/services/database_service.py
Normal file
@@ -0,0 +1,903 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
import sqlite3
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from leggen.utils.config import config
|
||||
import leggen.database.sqlite as sqlite_db
|
||||
from leggen.utils.paths import path_manager
|
||||
|
||||
|
||||
class DatabaseService:
|
||||
def __init__(self):
|
||||
self.db_config = config.database_config
|
||||
self.sqlite_enabled = self.db_config.get("sqlite", True)
|
||||
|
||||
async def persist_balance(
|
||||
self, account_id: str, balance_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist account balance data"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, skipping balance persistence")
|
||||
return
|
||||
|
||||
await self._persist_balance_sqlite(account_id, balance_data)
|
||||
|
||||
async def persist_transactions(
|
||||
self, account_id: str, transactions: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Persist transactions and return new transactions"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, skipping transaction persistence")
|
||||
return transactions
|
||||
|
||||
return await self._persist_transactions_sqlite(account_id, transactions)
|
||||
|
||||
def process_transactions(
|
||||
self,
|
||||
account_id: str,
|
||||
account_info: Dict[str, Any],
|
||||
transaction_data: Dict[str, Any],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process raw transaction data into standardized format"""
|
||||
transactions = []
|
||||
|
||||
# Process booked transactions
|
||||
for transaction in transaction_data.get("transactions", {}).get("booked", []):
|
||||
processed = self._process_single_transaction(
|
||||
account_id, account_info, transaction, "booked"
|
||||
)
|
||||
transactions.append(processed)
|
||||
|
||||
# Process pending transactions
|
||||
for transaction in transaction_data.get("transactions", {}).get("pending", []):
|
||||
processed = self._process_single_transaction(
|
||||
account_id, account_info, transaction, "pending"
|
||||
)
|
||||
transactions.append(processed)
|
||||
|
||||
return transactions
|
||||
|
||||
def _process_single_transaction(
|
||||
self,
|
||||
account_id: str,
|
||||
account_info: Dict[str, Any],
|
||||
transaction: Dict[str, Any],
|
||||
status: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a single transaction into standardized format"""
|
||||
# Extract dates
|
||||
booked_date = transaction.get("bookingDateTime") or transaction.get(
|
||||
"bookingDate"
|
||||
)
|
||||
value_date = transaction.get("valueDateTime") or transaction.get("valueDate")
|
||||
|
||||
if booked_date and value_date:
|
||||
min_date = min(
|
||||
datetime.fromisoformat(booked_date), datetime.fromisoformat(value_date)
|
||||
)
|
||||
else:
|
||||
date_str = booked_date or value_date
|
||||
if not date_str:
|
||||
raise ValueError("No valid date found in transaction")
|
||||
min_date = datetime.fromisoformat(date_str)
|
||||
|
||||
# Extract amount and currency
|
||||
transaction_amount = transaction.get("transactionAmount", {})
|
||||
amount = float(transaction_amount.get("amount", 0))
|
||||
currency = transaction_amount.get("currency", "")
|
||||
|
||||
# Extract description
|
||||
description = transaction.get(
|
||||
"remittanceInformationUnstructured",
|
||||
",".join(transaction.get("remittanceInformationUnstructuredArray", [])),
|
||||
)
|
||||
|
||||
# Extract transaction IDs - transactionId is now primary, internalTransactionId is reference
|
||||
transaction_id = transaction.get("transactionId")
|
||||
internal_transaction_id = transaction.get("internalTransactionId")
|
||||
|
||||
if not transaction_id:
|
||||
raise ValueError("Transaction missing required transactionId field")
|
||||
|
||||
return {
|
||||
"accountId": account_id,
|
||||
"transactionId": transaction_id,
|
||||
"internalTransactionId": internal_transaction_id,
|
||||
"institutionId": account_info["institution_id"],
|
||||
"iban": account_info.get("iban", "N/A"),
|
||||
"transactionDate": min_date,
|
||||
"description": description,
|
||||
"transactionValue": amount,
|
||||
"transactionCurrency": currency,
|
||||
"transactionStatus": status,
|
||||
"rawTransaction": transaction,
|
||||
}
|
||||
|
||||
async def get_transactions_from_db(
|
||||
self,
|
||||
account_id: Optional[str] = None,
|
||||
limit: Optional[int] = None, # None means no limit, used for stats
|
||||
offset: Optional[int] = 0,
|
||||
date_from: Optional[str] = None,
|
||||
date_to: Optional[str] = None,
|
||||
min_amount: Optional[float] = None,
|
||||
max_amount: Optional[float] = None,
|
||||
search: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get transactions from SQLite database"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, cannot read transactions")
|
||||
return []
|
||||
|
||||
try:
|
||||
transactions = sqlite_db.get_transactions(
|
||||
account_id=account_id,
|
||||
limit=limit, # Pass limit as-is, None means no limit
|
||||
offset=offset or 0,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
min_amount=min_amount,
|
||||
max_amount=max_amount,
|
||||
search=search,
|
||||
)
|
||||
logger.debug(f"Retrieved {len(transactions)} transactions from database")
|
||||
return transactions
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transactions from database: {e}")
|
||||
return []
|
||||
|
||||
async def get_transaction_count_from_db(
|
||||
self,
|
||||
account_id: Optional[str] = None,
|
||||
date_from: Optional[str] = None,
|
||||
date_to: Optional[str] = None,
|
||||
min_amount: Optional[float] = None,
|
||||
max_amount: Optional[float] = None,
|
||||
search: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get total count of transactions from SQLite database"""
|
||||
if not self.sqlite_enabled:
|
||||
return 0
|
||||
|
||||
try:
|
||||
filters = {
|
||||
"date_from": date_from,
|
||||
"date_to": date_to,
|
||||
"min_amount": min_amount,
|
||||
"max_amount": max_amount,
|
||||
"search": search,
|
||||
}
|
||||
# Remove None values
|
||||
filters = {k: v for k, v in filters.items() if v is not None}
|
||||
|
||||
count = sqlite_db.get_transaction_count(account_id=account_id, **filters)
|
||||
logger.debug(f"Total transaction count: {count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transaction count from database: {e}")
|
||||
return 0
|
||||
|
||||
async def get_balances_from_db(
|
||||
self, account_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get balances from SQLite database"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, cannot read balances")
|
||||
return []
|
||||
|
||||
try:
|
||||
balances = sqlite_db.get_balances(account_id=account_id)
|
||||
logger.debug(f"Retrieved {len(balances)} balances from database")
|
||||
return balances
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get balances from database: {e}")
|
||||
return []
|
||||
|
||||
async def get_historical_balances_from_db(
|
||||
self, account_id: Optional[str] = None, days: int = 365
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get historical balance progression from SQLite database"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, cannot read historical balances")
|
||||
return []
|
||||
|
||||
try:
|
||||
balances = sqlite_db.get_historical_balances(
|
||||
account_id=account_id, days=days
|
||||
)
|
||||
logger.debug(
|
||||
f"Retrieved {len(balances)} historical balance points from database"
|
||||
)
|
||||
return balances
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get historical balances from database: {e}")
|
||||
return []
|
||||
|
||||
async def get_account_summary_from_db(
|
||||
self, account_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get basic account info from SQLite database (avoids GoCardless call)"""
|
||||
if not self.sqlite_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
summary = sqlite_db.get_account_summary(account_id)
|
||||
if summary:
|
||||
logger.debug(
|
||||
f"Retrieved account summary from database for {account_id}"
|
||||
)
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get account summary from database: {e}")
|
||||
return None
|
||||
|
||||
async def persist_account_details(self, account_data: Dict[str, Any]) -> None:
|
||||
"""Persist account details to database"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, skipping account persistence")
|
||||
return
|
||||
|
||||
await self._persist_account_details_sqlite(account_data)
|
||||
|
||||
async def get_accounts_from_db(
|
||||
self, account_ids: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get account details from database"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, cannot read accounts")
|
||||
return []
|
||||
|
||||
try:
|
||||
accounts = sqlite_db.get_accounts(account_ids=account_ids)
|
||||
logger.debug(f"Retrieved {len(accounts)} accounts from database")
|
||||
return accounts
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get accounts from database: {e}")
|
||||
return []
|
||||
|
||||
async def get_account_details_from_db(
|
||||
self, account_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get specific account details from database"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.warning("SQLite database disabled, cannot read account")
|
||||
return None
|
||||
|
||||
try:
|
||||
account = sqlite_db.get_account(account_id)
|
||||
if account:
|
||||
logger.debug(
|
||||
f"Retrieved account details from database for {account_id}"
|
||||
)
|
||||
return account
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get account details from database: {e}")
|
||||
return None
|
||||
|
||||
async def run_migrations_if_needed(self):
|
||||
"""Run all necessary database migrations"""
|
||||
if not self.sqlite_enabled:
|
||||
logger.info("SQLite database disabled, skipping migrations")
|
||||
return
|
||||
|
||||
await self._migrate_balance_timestamps_if_needed()
|
||||
await self._migrate_null_transaction_ids_if_needed()
|
||||
await self._migrate_to_composite_key_if_needed()
|
||||
|
||||
async def _migrate_balance_timestamps_if_needed(self):
|
||||
"""Check and migrate balance timestamps if needed"""
|
||||
try:
|
||||
if await self._check_balance_timestamp_migration_needed():
|
||||
logger.info("Balance timestamp migration needed, starting...")
|
||||
await self._migrate_balance_timestamps()
|
||||
logger.info("Balance timestamp migration completed")
|
||||
else:
|
||||
logger.info("Balance timestamps are already consistent")
|
||||
except Exception as e:
|
||||
logger.error(f"Balance timestamp migration failed: {e}")
|
||||
raise
|
||||
|
||||
async def _check_balance_timestamp_migration_needed(self) -> bool:
|
||||
"""Check if balance timestamps need migration"""
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check for mixed timestamp types
|
||||
cursor.execute("""
|
||||
SELECT typeof(timestamp) as type, COUNT(*) as count
|
||||
FROM balances
|
||||
GROUP BY typeof(timestamp)
|
||||
""")
|
||||
|
||||
types = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
# If we have both 'real' and 'text' types, migration is needed
|
||||
type_names = [row[0] for row in types]
|
||||
return "real" in type_names and "text" in type_names
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check migration status: {e}")
|
||||
return False
|
||||
|
||||
async def _migrate_balance_timestamps(self):
|
||||
"""Convert all Unix timestamps to datetime strings"""
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
logger.warning("Database file not found, skipping migration")
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get all balances with REAL timestamps
|
||||
cursor.execute("""
|
||||
SELECT id, timestamp
|
||||
FROM balances
|
||||
WHERE typeof(timestamp) = 'real'
|
||||
ORDER BY id
|
||||
""")
|
||||
|
||||
unix_records = cursor.fetchall()
|
||||
total_records = len(unix_records)
|
||||
|
||||
if total_records == 0:
|
||||
logger.info("No Unix timestamps found to migrate")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Migrating {total_records} balance records from Unix to datetime format"
|
||||
)
|
||||
|
||||
# Convert and update in batches
|
||||
batch_size = 100
|
||||
migrated_count = 0
|
||||
|
||||
for i in range(0, total_records, batch_size):
|
||||
batch = unix_records[i : i + batch_size]
|
||||
|
||||
for record_id, unix_timestamp in batch:
|
||||
try:
|
||||
# Convert Unix timestamp to datetime string
|
||||
dt_string = self._unix_to_datetime_string(float(unix_timestamp))
|
||||
|
||||
# Update the record
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE balances
|
||||
SET timestamp = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(dt_string, record_id),
|
||||
)
|
||||
|
||||
migrated_count += 1
|
||||
|
||||
if migrated_count % 100 == 0:
|
||||
logger.info(
|
||||
f"Migrated {migrated_count}/{total_records} balance records"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate record {record_id}: {e}")
|
||||
continue
|
||||
|
||||
# Commit batch
|
||||
conn.commit()
|
||||
|
||||
conn.close()
|
||||
logger.info(f"Successfully migrated {migrated_count} balance records")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Balance timestamp migration failed: {e}")
|
||||
raise
|
||||
|
||||
async def _migrate_null_transaction_ids_if_needed(self):
|
||||
"""Check and migrate null transaction IDs if needed"""
|
||||
try:
|
||||
if await self._check_null_transaction_ids_migration_needed():
|
||||
logger.info("Null transaction IDs migration needed, starting...")
|
||||
await self._migrate_null_transaction_ids()
|
||||
logger.info("Null transaction IDs migration completed")
|
||||
else:
|
||||
logger.info("No null transaction IDs found to migrate")
|
||||
except Exception as e:
|
||||
logger.error(f"Null transaction IDs migration failed: {e}")
|
||||
raise
|
||||
|
||||
async def _check_null_transaction_ids_migration_needed(self) -> bool:
|
||||
"""Check if null transaction IDs need migration"""
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check for transactions with null or empty internalTransactionId
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*)
|
||||
FROM transactions
|
||||
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
|
||||
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
|
||||
""")
|
||||
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
return count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check null transaction IDs migration status: {e}")
|
||||
return False
|
||||
|
||||
async def _migrate_null_transaction_ids(self):
|
||||
"""Populate null internalTransactionId fields using transactionId from raw data"""
|
||||
import uuid
|
||||
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
logger.warning("Database file not found, skipping migration")
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get all transactions with null/empty internalTransactionId but valid transactionId in raw data
|
||||
cursor.execute("""
|
||||
SELECT rowid, json_extract(rawTransaction, '$.transactionId') as transactionId
|
||||
FROM transactions
|
||||
WHERE (internalTransactionId IS NULL OR internalTransactionId = '')
|
||||
AND json_extract(rawTransaction, '$.transactionId') IS NOT NULL
|
||||
ORDER BY rowid
|
||||
""")
|
||||
|
||||
null_records = cursor.fetchall()
|
||||
total_records = len(null_records)
|
||||
|
||||
if total_records == 0:
|
||||
logger.info("No null transaction IDs found to migrate")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Migrating {total_records} transaction records with null internalTransactionId"
|
||||
)
|
||||
|
||||
# Update in batches
|
||||
batch_size = 100
|
||||
migrated_count = 0
|
||||
skipped_duplicates = 0
|
||||
|
||||
for i in range(0, total_records, batch_size):
|
||||
batch = null_records[i : i + batch_size]
|
||||
|
||||
for rowid, transaction_id in batch:
|
||||
try:
|
||||
# Check if this transactionId is already used by another record
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM transactions WHERE internalTransactionId = ?",
|
||||
(str(transaction_id),),
|
||||
)
|
||||
existing_count = cursor.fetchone()[0]
|
||||
|
||||
if existing_count > 0:
|
||||
# Generate a unique ID to avoid constraint violation
|
||||
unique_id = f"{str(transaction_id)}_{uuid.uuid4().hex[:8]}"
|
||||
logger.debug(
|
||||
f"Generated unique ID for duplicate transactionId: {unique_id}"
|
||||
)
|
||||
else:
|
||||
# Use the original transactionId
|
||||
unique_id = str(transaction_id)
|
||||
|
||||
# Update the record
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE transactions
|
||||
SET internalTransactionId = ?
|
||||
WHERE rowid = ?
|
||||
""",
|
||||
(unique_id, rowid),
|
||||
)
|
||||
|
||||
migrated_count += 1
|
||||
|
||||
if migrated_count % 100 == 0:
|
||||
logger.info(
|
||||
f"Migrated {migrated_count}/{total_records} transaction records"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate record {rowid}: {e}")
|
||||
continue
|
||||
|
||||
# Commit batch
|
||||
conn.commit()
|
||||
|
||||
conn.close()
|
||||
logger.info(f"Successfully migrated {migrated_count} transaction records")
|
||||
if skipped_duplicates > 0:
|
||||
logger.info(
|
||||
f"Generated unique IDs for {skipped_duplicates} duplicate transactionIds"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Null transaction IDs migration failed: {e}")
|
||||
raise
|
||||
|
||||
async def _migrate_to_composite_key_if_needed(self):
|
||||
"""Check and migrate to composite primary key if needed"""
|
||||
try:
|
||||
if await self._check_composite_key_migration_needed():
|
||||
logger.info("Composite key migration needed, starting...")
|
||||
await self._migrate_to_composite_key()
|
||||
logger.info("Composite key migration completed")
|
||||
else:
|
||||
logger.info("Composite key migration not needed")
|
||||
except Exception as e:
|
||||
logger.error(f"Composite key migration failed: {e}")
|
||||
raise
|
||||
|
||||
async def _check_composite_key_migration_needed(self) -> bool:
|
||||
"""Check if composite key migration is needed"""
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if transactions table exists
|
||||
cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='transactions'"
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
conn.close()
|
||||
return False
|
||||
|
||||
# Check if transactions table has the old primary key structure
|
||||
cursor.execute("PRAGMA table_info(transactions)")
|
||||
columns = cursor.fetchall()
|
||||
|
||||
# Check if internalTransactionId is the primary key (old structure)
|
||||
internal_transaction_id_is_pk = any(
|
||||
col[1] == "internalTransactionId" and col[5] == 1 # col[5] is pk flag
|
||||
for col in columns
|
||||
)
|
||||
|
||||
# Check if we have the new composite primary key structure
|
||||
has_composite_key = any(
|
||||
col[1] in ["accountId", "transactionId"]
|
||||
and col[5] == 1 # col[5] is pk flag
|
||||
for col in columns
|
||||
)
|
||||
|
||||
conn.close()
|
||||
|
||||
# Migration is needed if:
|
||||
# 1. internalTransactionId is still the primary key (old structure), OR
|
||||
# 2. We don't have the new composite key structure yet
|
||||
return internal_transaction_id_is_pk or not has_composite_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check composite key migration status: {e}")
|
||||
return False
|
||||
|
||||
async def _migrate_to_composite_key(self):
|
||||
"""Migrate transactions table to use composite primary key (accountId, transactionId)"""
|
||||
db_path = path_manager.get_database_path()
|
||||
if not db_path.exists():
|
||||
logger.warning("Database file not found, skipping migration")
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
logger.info("Starting composite key migration...")
|
||||
|
||||
# Step 1: Create temporary table with new schema
|
||||
logger.info("Creating temporary table with composite primary key...")
|
||||
cursor.execute("DROP TABLE IF EXISTS transactions_temp")
|
||||
cursor.execute("""
|
||||
CREATE TABLE transactions_temp (
|
||||
accountId TEXT NOT NULL,
|
||||
transactionId TEXT NOT NULL,
|
||||
internalTransactionId TEXT,
|
||||
institutionId TEXT,
|
||||
iban TEXT,
|
||||
transactionDate DATETIME,
|
||||
description TEXT,
|
||||
transactionValue REAL,
|
||||
transactionCurrency TEXT,
|
||||
transactionStatus TEXT,
|
||||
rawTransaction JSON,
|
||||
PRIMARY KEY (accountId, transactionId)
|
||||
)
|
||||
""")
|
||||
|
||||
# Step 2: Insert deduplicated data (keep most recent duplicate)
|
||||
logger.info("Inserting deduplicated data...")
|
||||
cursor.execute("""
|
||||
INSERT INTO transactions_temp
|
||||
SELECT
|
||||
accountId,
|
||||
json_extract(rawTransaction, '$.transactionId') as transactionId,
|
||||
internalTransactionId,
|
||||
institutionId,
|
||||
iban,
|
||||
transactionDate,
|
||||
description,
|
||||
transactionValue,
|
||||
transactionCurrency,
|
||||
transactionStatus,
|
||||
rawTransaction
|
||||
FROM (
|
||||
SELECT *,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY accountId, json_extract(rawTransaction, '$.transactionId')
|
||||
ORDER BY transactionDate DESC, rowid DESC
|
||||
) as rn
|
||||
FROM transactions
|
||||
WHERE json_extract(rawTransaction, '$.transactionId') IS NOT NULL
|
||||
AND accountId IS NOT NULL
|
||||
) WHERE rn = 1
|
||||
""")
|
||||
|
||||
# Get counts for reporting
|
||||
cursor.execute("SELECT COUNT(*) FROM transactions")
|
||||
old_count = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) FROM transactions_temp")
|
||||
new_count = cursor.fetchone()[0]
|
||||
|
||||
duplicates_removed = old_count - new_count
|
||||
logger.info(
|
||||
f"Migration stats: {old_count} → {new_count} records ({duplicates_removed} duplicates removed)"
|
||||
)
|
||||
|
||||
# Step 3: Replace tables
|
||||
logger.info("Replacing tables...")
|
||||
cursor.execute("ALTER TABLE transactions RENAME TO transactions_old")
|
||||
cursor.execute("ALTER TABLE transactions_temp RENAME TO transactions")
|
||||
|
||||
# Step 4: Recreate indexes
|
||||
logger.info("Recreating indexes...")
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_internal_id ON transactions(internalTransactionId)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_date ON transactions(transactionDate)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_account_date ON transactions(accountId, transactionDate)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_transactions_amount ON transactions(transactionValue)"
|
||||
)
|
||||
|
||||
# Step 5: Cleanup
|
||||
logger.info("Cleaning up...")
|
||||
cursor.execute("DROP TABLE transactions_old")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Composite key migration completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Composite key migration failed: {e}")
|
||||
raise
|
||||
|
||||
def _unix_to_datetime_string(self, unix_timestamp: float) -> str:
|
||||
"""Convert Unix timestamp to datetime string"""
|
||||
dt = datetime.fromtimestamp(unix_timestamp)
|
||||
return dt.isoformat()
|
||||
|
||||
async def _persist_balance_sqlite(
|
||||
self, account_id: str, balance_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist balance to SQLite"""
|
||||
try:
|
||||
import sqlite3
|
||||
|
||||
db_path = path_manager.get_database_path()
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create the balances table if it doesn't exist
|
||||
cursor.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS balances (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
account_id TEXT,
|
||||
bank TEXT,
|
||||
status TEXT,
|
||||
iban TEXT,
|
||||
amount REAL,
|
||||
currency TEXT,
|
||||
type TEXT,
|
||||
timestamp DATETIME
|
||||
)"""
|
||||
)
|
||||
|
||||
# Create indexes for better performance
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_balances_account_id
|
||||
ON balances(account_id)"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_balances_timestamp
|
||||
ON balances(timestamp)"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp
|
||||
ON balances(account_id, type, timestamp)"""
|
||||
)
|
||||
|
||||
# Convert GoCardless balance format to our format and persist
|
||||
for balance in balance_data.get("balances", []):
|
||||
balance_amount = balance["balanceAmount"]
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
"""INSERT INTO balances (
|
||||
account_id,
|
||||
bank,
|
||||
status,
|
||||
iban,
|
||||
amount,
|
||||
currency,
|
||||
type,
|
||||
timestamp
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
account_id,
|
||||
balance_data.get("institution_id", "unknown"),
|
||||
balance_data.get("account_status"),
|
||||
balance_data.get("iban", "N/A"),
|
||||
float(balance_amount["amount"]),
|
||||
balance_amount["currency"],
|
||||
balance["balanceType"],
|
||||
datetime.now().isoformat(),
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError:
|
||||
logger.warning(f"Skipped duplicate balance for {account_id}")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info(f"Persisted balances to SQLite for account {account_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist balances to SQLite: {e}")
|
||||
raise
|
||||
|
||||
async def _persist_transactions_sqlite(
|
||||
self, account_id: str, transactions: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Persist transactions to SQLite"""
|
||||
try:
|
||||
import sqlite3
|
||||
import json
|
||||
|
||||
db_path = path_manager.get_database_path()
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# The table should already exist with the new schema from migration
|
||||
# If it doesn't exist, create it (for new installations)
|
||||
cursor.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS transactions (
|
||||
accountId TEXT NOT NULL,
|
||||
transactionId TEXT NOT NULL,
|
||||
internalTransactionId TEXT,
|
||||
institutionId TEXT,
|
||||
iban TEXT,
|
||||
transactionDate DATETIME,
|
||||
description TEXT,
|
||||
transactionValue REAL,
|
||||
transactionCurrency TEXT,
|
||||
transactionStatus TEXT,
|
||||
rawTransaction JSON,
|
||||
PRIMARY KEY (accountId, transactionId)
|
||||
)"""
|
||||
)
|
||||
|
||||
# Create indexes for better performance (if they don't exist)
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_transactions_internal_id
|
||||
ON transactions(internalTransactionId)"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_transactions_date
|
||||
ON transactions(transactionDate)"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_transactions_account_date
|
||||
ON transactions(accountId, transactionDate)"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""CREATE INDEX IF NOT EXISTS idx_transactions_amount
|
||||
ON transactions(transactionValue)"""
|
||||
)
|
||||
|
||||
# Prepare an SQL statement for inserting/replacing data
|
||||
insert_sql = """INSERT OR REPLACE INTO transactions (
|
||||
accountId,
|
||||
transactionId,
|
||||
internalTransactionId,
|
||||
institutionId,
|
||||
iban,
|
||||
transactionDate,
|
||||
description,
|
||||
transactionValue,
|
||||
transactionCurrency,
|
||||
transactionStatus,
|
||||
rawTransaction
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
|
||||
|
||||
new_transactions = []
|
||||
|
||||
for transaction in transactions:
|
||||
try:
|
||||
cursor.execute(
|
||||
insert_sql,
|
||||
(
|
||||
transaction["accountId"],
|
||||
transaction["transactionId"],
|
||||
transaction.get("internalTransactionId"),
|
||||
transaction["institutionId"],
|
||||
transaction["iban"],
|
||||
transaction["transactionDate"],
|
||||
transaction["description"],
|
||||
transaction["transactionValue"],
|
||||
transaction["transactionCurrency"],
|
||||
transaction["transactionStatus"],
|
||||
json.dumps(transaction["rawTransaction"]),
|
||||
),
|
||||
)
|
||||
new_transactions.append(transaction)
|
||||
except sqlite3.IntegrityError as e:
|
||||
logger.warning(
|
||||
f"Failed to insert transaction {transaction.get('transactionId')}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info(
|
||||
f"Persisted {len(new_transactions)} new transactions to SQLite for account {account_id}"
|
||||
)
|
||||
return new_transactions
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist transactions to SQLite: {e}")
|
||||
raise
|
||||
|
||||
async def _persist_account_details_sqlite(
|
||||
self, account_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist account details to SQLite"""
|
||||
try:
|
||||
# Use the sqlite_db module function
|
||||
sqlite_db.persist_account(account_data)
|
||||
|
||||
logger.info(
|
||||
f"Persisted account details to SQLite for account {account_data['id']}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist account details to SQLite: {e}")
|
||||
raise
|
||||
175
leggen/services/gocardless_service.py
Normal file
175
leggen/services/gocardless_service.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from leggen.utils.config import config
|
||||
from leggen.utils.paths import path_manager
|
||||
|
||||
|
||||
def _log_rate_limits(response):
|
||||
"""Log GoCardless API rate limit headers"""
|
||||
limit = response.headers.get("X-RateLimit-Limit")
|
||||
remaining = response.headers.get("X-RateLimit-Remaining")
|
||||
reset = response.headers.get("X-RateLimit-Reset")
|
||||
account_success_reset = response.headers.get("X-RateLimit-Account-Success-Reset")
|
||||
|
||||
if limit or remaining or reset or account_success_reset:
|
||||
logger.info(
|
||||
f"GoCardless rate limits - Limit: {limit}, Remaining: {remaining}, Reset: {reset}s, Account Success Reset: {account_success_reset}"
|
||||
)
|
||||
|
||||
|
||||
class GoCardlessService:
|
||||
def __init__(self):
|
||||
self.config = config.gocardless_config
|
||||
self.base_url = self.config.get(
|
||||
"url", "https://bankaccountdata.gocardless.com/api/v2"
|
||||
)
|
||||
self._token = None
|
||||
|
||||
async def _get_auth_headers(self) -> Dict[str, str]:
|
||||
"""Get authentication headers for GoCardless API"""
|
||||
token = await self._get_token()
|
||||
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
async def _get_token(self) -> str:
|
||||
"""Get access token for GoCardless API"""
|
||||
if self._token:
|
||||
return self._token
|
||||
|
||||
# Use path manager for auth file
|
||||
auth_file = path_manager.get_auth_file_path()
|
||||
|
||||
if auth_file.exists():
|
||||
try:
|
||||
with open(auth_file, "r") as f:
|
||||
auth = json.load(f)
|
||||
|
||||
if auth.get("access"):
|
||||
# Try to refresh the token
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/token/refresh/",
|
||||
json={"refresh": auth["refresh"]},
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
auth.update(response.json())
|
||||
self._save_auth(auth)
|
||||
self._token = auth["access"]
|
||||
return self._token
|
||||
except httpx.HTTPStatusError:
|
||||
logger.warning("Token refresh failed, creating new token")
|
||||
return await self._create_token()
|
||||
else:
|
||||
return await self._create_token()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading auth file: {e}")
|
||||
return await self._create_token()
|
||||
else:
|
||||
return await self._create_token()
|
||||
|
||||
async def _create_token(self) -> str:
|
||||
"""Create a new GoCardless access token"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/token/new/",
|
||||
json={
|
||||
"secret_id": self.config["key"],
|
||||
"secret_key": self.config["secret"],
|
||||
},
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
auth = response.json()
|
||||
self._save_auth(auth)
|
||||
self._token = auth["access"]
|
||||
return self._token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create GoCardless token: {e}")
|
||||
raise
|
||||
|
||||
def _save_auth(self, auth_data: dict):
|
||||
"""Save authentication data to file"""
|
||||
auth_file = Path.home() / ".config" / "leggen" / "auth.json"
|
||||
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(auth_file, "w") as f:
|
||||
json.dump(auth_data, f)
|
||||
|
||||
async def get_institutions(self, country: str = "PT") -> List[Dict[str, Any]]:
|
||||
"""Get available bank institutions for a country"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/institutions/",
|
||||
headers=headers,
|
||||
params={"country": country},
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def create_requisition(
|
||||
self, institution_id: str, redirect_url: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a bank connection requisition"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/requisitions/",
|
||||
headers=headers,
|
||||
json={"institution_id": institution_id, "redirect": redirect_url},
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_requisitions(self) -> Dict[str, Any]:
|
||||
"""Get all requisitions"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/requisitions/", headers=headers
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_account_details(self, account_id: str) -> Dict[str, Any]:
|
||||
"""Get account details"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/accounts/{account_id}/", headers=headers
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_account_balances(self, account_id: str) -> Dict[str, Any]:
|
||||
"""Get account balances"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/accounts/{account_id}/balances/", headers=headers
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_account_transactions(self, account_id: str) -> Dict[str, Any]:
|
||||
"""Get account transactions"""
|
||||
headers = await self._get_auth_headers()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/accounts/{account_id}/transactions/", headers=headers
|
||||
)
|
||||
_log_rate_limits(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
201
leggen/services/notification_service.py
Normal file
201
leggen/services/notification_service.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from leggen.utils.config import config
|
||||
|
||||
|
||||
class NotificationService:
|
||||
def __init__(self):
|
||||
self.notifications_config = config.notifications_config
|
||||
self.filters_config = config.filters_config
|
||||
|
||||
async def send_transaction_notifications(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Send notifications for new transactions that match filters"""
|
||||
if not self.filters_config:
|
||||
logger.info("No notification filters configured, skipping notifications")
|
||||
return
|
||||
|
||||
# Filter transactions that match notification criteria
|
||||
matching_transactions = self._filter_transactions(transactions)
|
||||
|
||||
if not matching_transactions:
|
||||
logger.info("No transactions matched notification filters")
|
||||
return
|
||||
|
||||
# Send to enabled notification services
|
||||
if self._is_discord_enabled():
|
||||
await self._send_discord_notifications(matching_transactions)
|
||||
|
||||
if self._is_telegram_enabled():
|
||||
await self._send_telegram_notifications(matching_transactions)
|
||||
|
||||
async def send_test_notification(self, service: str, message: str) -> bool:
|
||||
"""Send a test notification"""
|
||||
try:
|
||||
if service == "discord" and self._is_discord_enabled():
|
||||
await self._send_discord_test(message)
|
||||
return True
|
||||
elif service == "telegram" and self._is_telegram_enabled():
|
||||
await self._send_telegram_test(message)
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
f"Notification service '{service}' not enabled or not found"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send test notification to {service}: {e}")
|
||||
return False
|
||||
|
||||
async def send_expiry_notification(self, notification_data: Dict[str, Any]) -> None:
|
||||
"""Send notification about account expiry"""
|
||||
if self._is_discord_enabled():
|
||||
await self._send_discord_expiry(notification_data)
|
||||
|
||||
if self._is_telegram_enabled():
|
||||
await self._send_telegram_expiry(notification_data)
|
||||
|
||||
def _filter_transactions(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter transactions based on notification criteria"""
|
||||
matching = []
|
||||
filters_case_insensitive = self.filters_config.get("case-insensitive", [])
|
||||
filters_case_sensitive = self.filters_config.get("case-sensitive", [])
|
||||
|
||||
for transaction in transactions:
|
||||
description = transaction.get("description", "")
|
||||
description_lower = description.lower()
|
||||
|
||||
# Check case-insensitive filters
|
||||
for filter_value in filters_case_insensitive:
|
||||
if filter_value.lower() in description_lower:
|
||||
matching.append(
|
||||
{
|
||||
"name": transaction["description"],
|
||||
"value": transaction["transactionValue"],
|
||||
"currency": transaction["transactionCurrency"],
|
||||
"date": transaction["transactionDate"],
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
# Check case-sensitive filters
|
||||
for filter_value in filters_case_sensitive:
|
||||
if filter_value in description:
|
||||
matching.append(
|
||||
{
|
||||
"name": transaction["description"],
|
||||
"value": transaction["transactionValue"],
|
||||
"currency": transaction["transactionCurrency"],
|
||||
"date": transaction["transactionDate"],
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
return matching
|
||||
|
||||
def _is_discord_enabled(self) -> bool:
|
||||
"""Check if Discord notifications are enabled"""
|
||||
discord_config = self.notifications_config.get("discord", {})
|
||||
return bool(
|
||||
discord_config.get("webhook") and discord_config.get("enabled", True)
|
||||
)
|
||||
|
||||
def _is_telegram_enabled(self) -> bool:
|
||||
"""Check if Telegram notifications are enabled"""
|
||||
telegram_config = self.notifications_config.get("telegram", {})
|
||||
return bool(
|
||||
telegram_config.get("api-key")
|
||||
and telegram_config.get("chat-id")
|
||||
and telegram_config.get("enabled", True)
|
||||
)
|
||||
|
||||
async def _send_discord_notifications(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Send Discord notifications - placeholder implementation"""
|
||||
# Would import and use leggen.notifications.discord
|
||||
logger.info(f"Sending {len(transactions)} transaction notifications to Discord")
|
||||
|
||||
async def _send_telegram_notifications(
|
||||
self, transactions: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Send Telegram notifications - placeholder implementation"""
|
||||
# Would import and use leggen.notifications.telegram
|
||||
logger.info(
|
||||
f"Sending {len(transactions)} transaction notifications to Telegram"
|
||||
)
|
||||
|
||||
async def _send_discord_test(self, message: str) -> None:
|
||||
"""Send Discord test notification"""
|
||||
try:
|
||||
from leggen.notifications.discord import send_expire_notification
|
||||
import click
|
||||
|
||||
# Create a mock context with the webhook
|
||||
ctx = click.Context(click.Command("test"))
|
||||
ctx.obj = {
|
||||
"notifications": {
|
||||
"discord": {
|
||||
"webhook": self.notifications_config.get("discord", {}).get(
|
||||
"webhook"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Send test notification using the actual implementation
|
||||
test_notification = {
|
||||
"bank": "Test",
|
||||
"requisition_id": "test-123",
|
||||
"status": "active",
|
||||
"days_left": 30,
|
||||
}
|
||||
send_expire_notification(ctx, test_notification)
|
||||
logger.info(f"Discord test notification sent: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Discord test notification: {e}")
|
||||
raise
|
||||
|
||||
async def _send_telegram_test(self, message: str) -> None:
|
||||
"""Send Telegram test notification"""
|
||||
try:
|
||||
from leggen.notifications.telegram import send_expire_notification
|
||||
import click
|
||||
|
||||
# Create a mock context with the telegram config
|
||||
ctx = click.Context(click.Command("test"))
|
||||
telegram_config = self.notifications_config.get("telegram", {})
|
||||
ctx.obj = {
|
||||
"notifications": {
|
||||
"telegram": {
|
||||
"api-key": telegram_config.get("api-key"),
|
||||
"chat-id": telegram_config.get("chat-id"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Send test notification using the actual implementation
|
||||
test_notification = {
|
||||
"bank": "Test",
|
||||
"requisition_id": "test-123",
|
||||
"status": "active",
|
||||
"days_left": 30,
|
||||
}
|
||||
send_expire_notification(ctx, test_notification)
|
||||
logger.info(f"Telegram test notification sent: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Telegram test notification: {e}")
|
||||
raise
|
||||
|
||||
async def _send_discord_expiry(self, notification_data: Dict[str, Any]) -> None:
|
||||
"""Send Discord expiry notification"""
|
||||
logger.info(f"Sending Discord expiry notification: {notification_data}")
|
||||
|
||||
async def _send_telegram_expiry(self, notification_data: Dict[str, Any]) -> None:
|
||||
"""Send Telegram expiry notification"""
|
||||
logger.info(f"Sending Telegram expiry notification: {notification_data}")
|
||||
187
leggen/services/sync_service.py
Normal file
187
leggen/services/sync_service.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from leggen.api.models.sync import SyncResult, SyncStatus
|
||||
from leggen.services.gocardless_service import GoCardlessService
|
||||
from leggen.services.database_service import DatabaseService
|
||||
from leggen.services.notification_service import NotificationService
|
||||
|
||||
|
||||
class SyncService:
|
||||
def __init__(self):
|
||||
self.gocardless = GoCardlessService()
|
||||
self.database = DatabaseService()
|
||||
self.notifications = NotificationService()
|
||||
self._sync_status = SyncStatus(is_running=False)
|
||||
|
||||
async def get_sync_status(self) -> SyncStatus:
|
||||
"""Get current sync status"""
|
||||
return self._sync_status
|
||||
|
||||
async def sync_all_accounts(self, force: bool = False) -> SyncResult:
|
||||
"""Sync all connected accounts"""
|
||||
if self._sync_status.is_running and not force:
|
||||
raise Exception("Sync is already running")
|
||||
|
||||
start_time = datetime.now()
|
||||
self._sync_status.is_running = True
|
||||
self._sync_status.errors = []
|
||||
|
||||
accounts_processed = 0
|
||||
transactions_added = 0
|
||||
transactions_updated = 0
|
||||
balances_updated = 0
|
||||
errors = []
|
||||
|
||||
try:
|
||||
logger.info("Starting sync of all accounts")
|
||||
|
||||
# Get all requisitions and accounts
|
||||
requisitions = await self.gocardless.get_requisitions()
|
||||
all_accounts = set()
|
||||
|
||||
for req in requisitions.get("results", []):
|
||||
all_accounts.update(req.get("accounts", []))
|
||||
|
||||
self._sync_status.total_accounts = len(all_accounts)
|
||||
|
||||
# Process each account
|
||||
for account_id in all_accounts:
|
||||
try:
|
||||
# Get account details
|
||||
account_details = await self.gocardless.get_account_details(
|
||||
account_id
|
||||
)
|
||||
|
||||
# Get balances to extract currency information
|
||||
balances = await self.gocardless.get_account_balances(account_id)
|
||||
|
||||
# Enrich account details with currency and persist
|
||||
if account_details and balances:
|
||||
enriched_account_details = account_details.copy()
|
||||
|
||||
# Extract currency from first balance
|
||||
balances_list = balances.get("balances", [])
|
||||
if balances_list:
|
||||
first_balance = balances_list[0]
|
||||
balance_amount = first_balance.get("balanceAmount", {})
|
||||
currency = balance_amount.get("currency")
|
||||
if currency:
|
||||
enriched_account_details["currency"] = currency
|
||||
|
||||
# Persist enriched account details to database
|
||||
await self.database.persist_account_details(
|
||||
enriched_account_details
|
||||
)
|
||||
|
||||
# Merge account details into balances data for proper persistence
|
||||
balances_with_account_info = balances.copy()
|
||||
balances_with_account_info["institution_id"] = (
|
||||
enriched_account_details.get("institution_id")
|
||||
)
|
||||
balances_with_account_info["iban"] = (
|
||||
enriched_account_details.get("iban")
|
||||
)
|
||||
balances_with_account_info["account_status"] = (
|
||||
enriched_account_details.get("status")
|
||||
)
|
||||
await self.database.persist_balance(
|
||||
account_id, balances_with_account_info
|
||||
)
|
||||
balances_updated += len(balances.get("balances", []))
|
||||
elif account_details:
|
||||
# Fallback: persist account details without currency if balances failed
|
||||
await self.database.persist_account_details(account_details)
|
||||
|
||||
# Get and save transactions
|
||||
transactions = await self.gocardless.get_account_transactions(
|
||||
account_id
|
||||
)
|
||||
if transactions:
|
||||
processed_transactions = self.database.process_transactions(
|
||||
account_id, account_details, transactions
|
||||
)
|
||||
new_transactions = await self.database.persist_transactions(
|
||||
account_id, processed_transactions
|
||||
)
|
||||
transactions_added += len(new_transactions)
|
||||
|
||||
# Send notifications for new transactions
|
||||
if new_transactions:
|
||||
await self.notifications.send_transaction_notifications(
|
||||
new_transactions
|
||||
)
|
||||
|
||||
accounts_processed += 1
|
||||
self._sync_status.accounts_synced = accounts_processed
|
||||
|
||||
logger.info(f"Synced account {account_id} successfully")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to sync account {account_id}: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
self._sync_status.last_sync = end_time
|
||||
|
||||
result = SyncResult(
|
||||
success=len(errors) == 0,
|
||||
accounts_processed=accounts_processed,
|
||||
transactions_added=transactions_added,
|
||||
transactions_updated=transactions_updated,
|
||||
balances_updated=balances_updated,
|
||||
duration_seconds=duration,
|
||||
errors=errors,
|
||||
started_at=start_time,
|
||||
completed_at=end_time,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sync completed: {accounts_processed} accounts, {transactions_added} new transactions"
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Sync failed: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
finally:
|
||||
self._sync_status.is_running = False
|
||||
|
||||
async def sync_specific_accounts(
|
||||
self, account_ids: List[str], force: bool = False
|
||||
) -> SyncResult:
|
||||
"""Sync specific accounts"""
|
||||
if self._sync_status.is_running and not force:
|
||||
raise Exception("Sync is already running")
|
||||
|
||||
# Similar implementation but only for specified accounts
|
||||
# For brevity, implementing a simplified version
|
||||
start_time = datetime.now()
|
||||
self._sync_status.is_running = True
|
||||
|
||||
try:
|
||||
# Process only specified accounts
|
||||
# Implementation would be similar to sync_all_accounts
|
||||
# but filtered to only the specified account_ids
|
||||
|
||||
end_time = datetime.now()
|
||||
return SyncResult(
|
||||
success=True,
|
||||
accounts_processed=len(account_ids),
|
||||
transactions_added=0,
|
||||
transactions_updated=0,
|
||||
balances_updated=0,
|
||||
duration_seconds=(end_time - start_time).total_seconds(),
|
||||
errors=[],
|
||||
started_at=start_time,
|
||||
completed_at=end_time,
|
||||
)
|
||||
finally:
|
||||
self._sync_status.is_running = False
|
||||
@@ -1,9 +1,146 @@
|
||||
import os
|
||||
import sys
|
||||
import tomllib
|
||||
import tomli_w
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
from leggen.utils.text import error
|
||||
from leggen.utils.paths import path_manager
|
||||
|
||||
|
||||
class Config:
|
||||
_instance = None
|
||||
_config = None
|
||||
_config_path = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def load_config(self, config_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
if self._config is not None:
|
||||
return self._config
|
||||
|
||||
if config_path is None:
|
||||
config_path = os.environ.get("LEGGEN_CONFIG_FILE")
|
||||
if not config_path:
|
||||
config_path = str(path_manager.get_config_file_path())
|
||||
|
||||
self._config_path = config_path
|
||||
|
||||
try:
|
||||
with open(config_path, "rb") as f:
|
||||
self._config = tomllib.load(f)
|
||||
logger.info(f"Configuration loaded from {config_path}")
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Configuration file not found: {config_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading configuration: {e}")
|
||||
raise
|
||||
|
||||
return self._config
|
||||
|
||||
def save_config(
|
||||
self,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
config_path: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Save configuration to TOML file"""
|
||||
if config_data is None:
|
||||
config_data = self._config
|
||||
|
||||
if config_path is None:
|
||||
config_path = self._config_path or os.environ.get("LEGGEN_CONFIG_FILE")
|
||||
if not config_path:
|
||||
config_path = str(path_manager.get_config_file_path())
|
||||
|
||||
if config_path is None:
|
||||
raise ValueError("No config path specified")
|
||||
if config_data is None:
|
||||
raise ValueError("No config data to save")
|
||||
|
||||
# Ensure directory exists
|
||||
Path(config_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(config_path, "wb") as f:
|
||||
tomli_w.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
self._config = config_data
|
||||
self._config_path = config_path
|
||||
logger.info(f"Configuration saved to {config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving configuration: {e}")
|
||||
raise
|
||||
|
||||
def update_config(self, section: str, key: str, value: Any) -> None:
|
||||
"""Update a specific configuration value"""
|
||||
if self._config is None:
|
||||
self.load_config()
|
||||
|
||||
if self._config is None:
|
||||
raise RuntimeError("Failed to load config")
|
||||
|
||||
if section not in self._config:
|
||||
self._config[section] = {}
|
||||
|
||||
self._config[section][key] = value
|
||||
self.save_config()
|
||||
|
||||
def update_section(self, section: str, data: Dict[str, Any]) -> None:
|
||||
"""Update an entire configuration section"""
|
||||
if self._config is None:
|
||||
self.load_config()
|
||||
|
||||
if self._config is None:
|
||||
raise RuntimeError("Failed to load config")
|
||||
|
||||
self._config[section] = data
|
||||
self.save_config()
|
||||
|
||||
@property
|
||||
def config(self) -> Dict[str, Any]:
|
||||
if self._config is None:
|
||||
self.load_config()
|
||||
if self._config is None:
|
||||
raise RuntimeError("Failed to load config")
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def gocardless_config(self) -> Dict[str, str]:
|
||||
return self.config.get("gocardless", {})
|
||||
|
||||
@property
|
||||
def database_config(self) -> Dict[str, Any]:
|
||||
return self.config.get("database", {})
|
||||
|
||||
@property
|
||||
def notifications_config(self) -> Dict[str, Any]:
|
||||
return self.config.get("notifications", {})
|
||||
|
||||
@property
|
||||
def filters_config(self) -> Dict[str, Any]:
|
||||
return self.config.get("filters", {})
|
||||
|
||||
@property
|
||||
def scheduler_config(self) -> Dict[str, Any]:
|
||||
"""Get scheduler configuration with defaults"""
|
||||
default_schedule = {
|
||||
"sync": {
|
||||
"enabled": True,
|
||||
"hour": 3,
|
||||
"minute": 0,
|
||||
"cron": None, # Optional custom cron expression
|
||||
}
|
||||
}
|
||||
return self.config.get("scheduler", default_schedule)
|
||||
|
||||
|
||||
def load_config(ctx: click.Context, _, filename):
|
||||
@@ -16,3 +153,7 @@ def load_config(ctx: click.Context, _, filename):
|
||||
"Configuration file not found. Provide a valid configuration file path with leggen --config <path> or LEGGEN_CONFIG=<path> environment variable."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
config = Config()
|
||||
|
||||
Reference in New Issue
Block a user