feat(analytics): Fix transaction limits and improve chart legends

Co-authored-by: elisiariocouto <818914+elisiariocouto@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-09-13 19:18:36 +00:00
committed by Elisiário Couto
parent 692bee574e
commit e136fc4b75
15 changed files with 691 additions and 180 deletions

View File

@@ -8,10 +8,11 @@ import {
ResponsiveContainer, ResponsiveContainer,
Legend, Legend,
} from "recharts"; } from "recharts";
import type { Balance } from "../../types/api"; import type { Balance, Account } from "../../types/api";
interface BalanceChartProps { interface BalanceChartProps {
data: Balance[]; data: Balance[];
accounts: Account[];
className?: string; className?: string;
} }
@@ -26,7 +27,34 @@ interface AggregatedDataPoint {
[key: string]: string | number; [key: string]: string | number;
} }
export default function BalanceChart({ data, className }: BalanceChartProps) { export default function BalanceChart({ data, accounts, className }: BalanceChartProps) {
// Create a lookup map for account info
const accountMap = accounts.reduce((map, account) => {
map[account.id] = account;
return map;
}, {} as Record<string, Account>);
// Helper function to get bank name from institution_id
const getBankName = (institutionId: string): string => {
const bankMapping: Record<string, string> = {
'REVOLUT_REVOLT21': 'Revolut',
'NUBANK_NUPBBR25': 'Nu Pagamentos',
'BANCOBPI_BBPIPTPL': 'Banco BPI',
// Add more mappings as needed
};
return bankMapping[institutionId] || institutionId.split('_')[0];
};
// Helper function to create display name for account
const getAccountDisplayName = (accountId: string): string => {
const account = accountMap[accountId];
if (account) {
const bankName = getBankName(account.institution_id);
const accountName = account.name || `Account ${accountId.split('-')[1]}`;
return `${bankName} - ${accountName}`;
}
return `Account ${accountId.split('-')[1]}`;
};
// Process balance data for the chart // Process balance data for the chart
const chartData = data const chartData = data
.filter((balance) => balance.balance_type === "closingBooked") .filter((balance) => balance.balance_type === "closingBooked")
@@ -116,7 +144,7 @@ export default function BalanceChart({ data, className }: BalanceChartProps) {
stroke={colors[index % colors.length]} stroke={colors[index % colors.length]}
strokeWidth={2} strokeWidth={2}
dot={{ r: 4 }} dot={{ r: 4 }}
name={`Account ${accountId.split('-')[1]}`} name={getAccountDisplayName(accountId)}
/> />
))} ))}
</LineChart> </LineChart>

View File

@@ -32,18 +32,12 @@ interface TooltipProps {
} }
export default function MonthlyTrends({ className }: MonthlyTrendsProps) { export default function MonthlyTrends({ className }: MonthlyTrendsProps) {
// Get transactions for the last 12 months // Get transactions for the last 12 months using analytics endpoint
const { data: transactions, isLoading } = useQuery({ const { data: transactions, isLoading } = useQuery({
queryKey: ["transactions", "monthly-trends"], queryKey: ["transactions", "monthly-trends"],
queryFn: async () => { queryFn: async () => {
const response = await apiClient.getTransactions({ // Get last 365 days of transactions for monthly trends
startDate: new Date( return await apiClient.getTransactionsForAnalytics(365);
Date.now() - 365 * 24 * 60 * 60 * 1000
).toISOString().split("T")[0],
endDate: new Date().toISOString().split("T")[0],
perPage: 1000,
});
return response.data;
}, },
}); });
@@ -54,7 +48,7 @@ export default function MonthlyTrends({ className }: MonthlyTrendsProps) {
const monthlyMap: { [key: string]: MonthlyData } = {}; const monthlyMap: { [key: string]: MonthlyData } = {};
transactions.forEach((transaction) => { transactions.forEach((transaction) => {
const date = new Date(transaction.transaction_date); const date = new Date(transaction.date);
const monthKey = `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}`; const monthKey = `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}`;
if (!monthlyMap[monthKey]) { if (!monthlyMap[monthKey]) {
@@ -69,10 +63,10 @@ export default function MonthlyTrends({ className }: MonthlyTrendsProps) {
}; };
} }
if (transaction.transaction_value > 0) { if (transaction.amount > 0) {
monthlyMap[monthKey].income += transaction.transaction_value; monthlyMap[monthKey].income += transaction.amount;
} else { } else {
monthlyMap[monthKey].expenses += Math.abs(transaction.transaction_value); monthlyMap[monthKey].expenses += Math.abs(transaction.amount);
} }
monthlyMap[monthKey].net = monthlyMap[monthKey].income - monthlyMap[monthKey].expenses; monthlyMap[monthKey].net = monthlyMap[monthKey].income - monthlyMap[monthKey].expenses;

View File

@@ -30,6 +30,24 @@ export default function TransactionDistribution({
accounts, accounts,
className, className,
}: TransactionDistributionProps) { }: TransactionDistributionProps) {
// Helper function to get bank name from institution_id
const getBankName = (institutionId: string): string => {
const bankMapping: Record<string, string> = {
'REVOLUT_REVOLT21': 'Revolut',
'NUBANK_NUPBBR25': 'Nu Pagamentos',
'BANCOBPI_BBPIPTPL': 'Banco BPI',
// Add more mappings as needed
};
return bankMapping[institutionId] || institutionId.split('_')[0];
};
// Helper function to create display name for account
const getAccountDisplayName = (account: Account): string => {
const bankName = getBankName(account.institution_id);
const accountName = account.name || `Account ${account.id.split('-')[1]}`;
return `${bankName} - ${accountName}`;
};
// Create pie chart data from account balances // Create pie chart data from account balances
const pieData: PieDataPoint[] = accounts.map((account, index) => { const pieData: PieDataPoint[] = accounts.map((account, index) => {
const closingBalance = account.balances.find( const closingBalance = account.balances.find(
@@ -39,7 +57,7 @@ export default function TransactionDistribution({
const colors = ["#3B82F6", "#10B981", "#F59E0B", "#EF4444", "#8B5CF6"]; const colors = ["#3B82F6", "#10B981", "#F59E0B", "#EF4444", "#8B5CF6"];
return { return {
name: account.name || `Account ${account.id.split('-')[1]}`, name: getAccountDisplayName(account),
value: closingBalance?.amount || 0, value: closingBalance?.amount || 0,
color: colors[index % colors.length], color: colors[index % colors.length],
}; };

View File

@@ -154,6 +154,17 @@ export const apiClient = {
); );
return response.data.data; return response.data.data;
}, },
// Get all transactions for analytics (no pagination)
getTransactionsForAnalytics: async (days?: number): Promise<Transaction[]> => {
const queryParams = new URLSearchParams();
if (days) queryParams.append("days", days.toString());
const response = await api.get<ApiResponse<Transaction[]>>(
`/transactions/analytics?${queryParams.toString()}`
);
return response.data.data;
},
}; };
export default apiClient; export default apiClient;

View File

@@ -126,7 +126,7 @@ function AnalyticsDashboard() {
{/* Charts */} {/* Charts */}
<div className="grid grid-cols-1 lg:grid-cols-2 gap-8"> <div className="grid grid-cols-1 lg:grid-cols-2 gap-8">
<div className="bg-white rounded-lg shadow p-6 border border-gray-200"> <div className="bg-white rounded-lg shadow p-6 border border-gray-200">
<BalanceChart data={balances || []} /> <BalanceChart data={balances || []} accounts={accounts || []} />
</div> </div>
<div className="bg-white rounded-lg shadow p-6 border border-gray-200"> <div className="bg-white rounded-lg shadow p-6 border border-gray-200">
<TransactionDistribution accounts={accounts || []} /> <TransactionDistribution accounts={accounts || []} />

View File

@@ -30,29 +30,33 @@ from leggen.utils.paths import path_manager
help="Overwrite existing database without confirmation", help="Overwrite existing database without confirmation",
) )
@click.pass_context @click.pass_context
def generate_sample_db(ctx: click.Context, database: Path, accounts: int, transactions: int, force: bool): def generate_sample_db(
ctx: click.Context, database: Path, accounts: int, transactions: int, force: bool
):
"""Generate a sample database with realistic financial data for testing.""" """Generate a sample database with realistic financial data for testing."""
# Import here to avoid circular imports # Import here to avoid circular imports
import sys import sys
import subprocess import subprocess
from pathlib import Path as PathlibPath from pathlib import Path as PathlibPath
# Get the script path # Get the script path
script_path = PathlibPath(__file__).parent.parent.parent / "scripts" / "generate_sample_db.py" script_path = (
PathlibPath(__file__).parent.parent.parent / "scripts" / "generate_sample_db.py"
)
# Build command arguments # Build command arguments
cmd = [sys.executable, str(script_path)] cmd = [sys.executable, str(script_path)]
if database: if database:
cmd.extend(["--database", str(database)]) cmd.extend(["--database", str(database)])
cmd.extend(["--accounts", str(accounts)]) cmd.extend(["--accounts", str(accounts)])
cmd.extend(["--transactions", str(transactions)]) cmd.extend(["--transactions", str(transactions)])
if force: if force:
cmd.append("--force") cmd.append("--force")
# Execute the script # Execute the script
try: try:
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)
@@ -62,4 +66,4 @@ def generate_sample_db(ctx: click.Context, database: Path, accounts: int, transa
# Export the command # Export the command
generate_sample_db = generate_sample_db generate_sample_db = generate_sample_db

View File

@@ -7,32 +7,32 @@ from typing import Optional
class PathManager: class PathManager:
"""Manages configurable paths for config and database files.""" """Manages configurable paths for config and database files."""
def __init__(self): def __init__(self):
self._config_dir: Optional[Path] = None self._config_dir: Optional[Path] = None
self._database_path: Optional[Path] = None self._database_path: Optional[Path] = None
def get_config_dir(self) -> Path: def get_config_dir(self) -> Path:
"""Get the configuration directory.""" """Get the configuration directory."""
if self._config_dir is not None: if self._config_dir is not None:
return self._config_dir return self._config_dir
# Check environment variable first # Check environment variable first
config_dir = os.environ.get("LEGGEN_CONFIG_DIR") config_dir = os.environ.get("LEGGEN_CONFIG_DIR")
if config_dir: if config_dir:
return Path(config_dir) return Path(config_dir)
# Default to ~/.config/leggen # Default to ~/.config/leggen
return Path.home() / ".config" / "leggen" return Path.home() / ".config" / "leggen"
def set_config_dir(self, path: Path) -> None: def set_config_dir(self, path: Path) -> None:
"""Set the configuration directory.""" """Set the configuration directory."""
self._config_dir = Path(path) self._config_dir = Path(path)
def get_config_file_path(self) -> Path: def get_config_file_path(self) -> Path:
"""Get the configuration file path.""" """Get the configuration file path."""
return self.get_config_dir() / "config.toml" return self.get_config_dir() / "config.toml"
def get_database_path(self) -> Path: def get_database_path(self) -> Path:
"""Get the database file path and ensure the directory exists.""" """Get the database file path and ensure the directory exists."""
if self._database_path is not None: if self._database_path is not None:
@@ -45,7 +45,7 @@ class PathManager:
else: else:
# Default to config_dir/leggen.db # Default to config_dir/leggen.db
db_path = self.get_config_dir() / "leggen.db" db_path = self.get_config_dir() / "leggen.db"
# Try to ensure the directory exists, but handle permission errors gracefully # Try to ensure the directory exists, but handle permission errors gracefully
try: try:
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
@@ -53,24 +53,24 @@ class PathManager:
# If we can't create the directory, continue anyway # If we can't create the directory, continue anyway
# This allows tests and error cases to work as expected # This allows tests and error cases to work as expected
pass pass
return db_path return db_path
def set_database_path(self, path: Path) -> None: def set_database_path(self, path: Path) -> None:
"""Set the database file path.""" """Set the database file path."""
self._database_path = Path(path) self._database_path = Path(path)
def get_auth_file_path(self) -> Path: def get_auth_file_path(self) -> Path:
"""Get the authentication file path.""" """Get the authentication file path."""
return self.get_config_dir() / "auth.json" return self.get_config_dir() / "auth.json"
def ensure_config_dir_exists(self) -> None: def ensure_config_dir_exists(self) -> None:
"""Ensure the configuration directory exists.""" """Ensure the configuration directory exists."""
self.get_config_dir().mkdir(parents=True, exist_ok=True) self.get_config_dir().mkdir(parents=True, exist_ok=True)
def ensure_database_dir_exists(self) -> None: def ensure_database_dir_exists(self) -> None:
"""Ensure the database directory exists. """Ensure the database directory exists.
Note: get_database_path() now automatically ensures the directory exists, Note: get_database_path() now automatically ensures the directory exists,
so this method is mainly for explicit directory creation in tests. so this method is mainly for explicit directory creation in tests.
""" """
@@ -78,4 +78,4 @@ class PathManager:
# Global instance for the application # Global instance for the application
path_manager = PathManager() path_manager = PathManager()

View File

@@ -121,6 +121,219 @@ async def get_all_transactions(
) from e ) from e
@router.get("/transactions/enhanced-stats", response_model=APIResponse)
async def get_enhanced_transaction_stats(
days: int = Query(default=365, description="Number of days to include in stats"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> APIResponse:
"""Get enhanced transaction statistics with monthly breakdown and account details"""
try:
# Date range for stats
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
# Format dates for database query
date_from = start_date.isoformat()
date_to = end_date.isoformat()
# Get all transactions from database for comprehensive stats
recent_transactions = await database_service.get_transactions_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
limit=None, # Get all matching transactions
)
# Basic stats
total_transactions = len(recent_transactions)
total_income = sum(
txn["transactionValue"]
for txn in recent_transactions
if txn["transactionValue"] > 0
)
total_expenses = sum(
abs(txn["transactionValue"])
for txn in recent_transactions
if txn["transactionValue"] < 0
)
net_change = total_income - total_expenses
# Count by status
booked_count = len(
[txn for txn in recent_transactions if txn["transactionStatus"] == "booked"]
)
pending_count = len(
[
txn
for txn in recent_transactions
if txn["transactionStatus"] == "pending"
]
)
# Count unique accounts
unique_accounts = len({txn["accountId"] for txn in recent_transactions})
# Monthly breakdown
monthly_stats = {}
for txn in recent_transactions:
try:
txn_date = datetime.fromisoformat(
txn["transactionDate"].replace("Z", "+00:00")
)
month_key = txn_date.strftime("%Y-%m")
if month_key not in monthly_stats:
monthly_stats[month_key] = {
"month": txn_date.strftime("%Y %b"),
"income": 0,
"expenses": 0,
"net": 0,
"transaction_count": 0,
}
monthly_stats[month_key]["transaction_count"] += 1
if txn["transactionValue"] > 0:
monthly_stats[month_key]["income"] += txn["transactionValue"]
else:
monthly_stats[month_key]["expenses"] += abs(txn["transactionValue"])
monthly_stats[month_key]["net"] = (
monthly_stats[month_key]["income"]
- monthly_stats[month_key]["expenses"]
)
except (ValueError, TypeError):
# Skip transactions with invalid dates
continue
# Account breakdown
account_stats = {}
for txn in recent_transactions:
acc_id = txn["accountId"]
if acc_id not in account_stats:
account_stats[acc_id] = {
"account_id": acc_id,
"transaction_count": 0,
"income": 0,
"expenses": 0,
"net": 0,
}
account_stats[acc_id]["transaction_count"] += 1
if txn["transactionValue"] > 0:
account_stats[acc_id]["income"] += txn["transactionValue"]
else:
account_stats[acc_id]["expenses"] += abs(txn["transactionValue"])
account_stats[acc_id]["net"] = (
account_stats[acc_id]["income"] - account_stats[acc_id]["expenses"]
)
enhanced_stats = {
"period_days": days,
"date_range": {
"start": start_date.isoformat(),
"end": end_date.isoformat(),
},
"summary": {
"total_transactions": total_transactions,
"booked_transactions": booked_count,
"pending_transactions": pending_count,
"total_income": round(total_income, 2),
"total_expenses": round(total_expenses, 2),
"net_change": round(net_change, 2),
"average_transaction": round(
sum(txn["transactionValue"] for txn in recent_transactions)
/ total_transactions,
2,
)
if total_transactions > 0
else 0,
"accounts_included": unique_accounts,
},
"monthly_breakdown": [
{
**stats,
"income": round(stats["income"], 2),
"expenses": round(stats["expenses"], 2),
"net": round(stats["net"], 2),
}
for month, stats in sorted(monthly_stats.items())
],
"account_breakdown": [
{
**stats,
"income": round(stats["income"], 2),
"expenses": round(stats["expenses"], 2),
"net": round(stats["net"], 2),
}
for stats in account_stats.values()
],
}
return APIResponse(
success=True,
data=enhanced_stats,
message=f"Enhanced transaction statistics for last {days} days",
)
except Exception as e:
logger.error(f"Failed to get enhanced transaction stats: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get enhanced transaction stats: {str(e)}",
) from e
@router.get("/transactions/analytics", response_model=APIResponse)
async def get_transactions_for_analytics(
days: int = Query(default=365, description="Number of days to include"),
account_id: Optional[str] = Query(default=None, description="Filter by account ID"),
) -> APIResponse:
"""Get all transactions for analytics (no pagination) for the last N days"""
try:
# Date range for analytics
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
# Format dates for database query
date_from = start_date.isoformat()
date_to = end_date.isoformat()
# Get ALL transactions from database (no limit for analytics)
transactions = await database_service.get_transactions_from_db(
account_id=account_id,
date_from=date_from,
date_to=date_to,
limit=None, # No limit - get all transactions
)
# Transform for frontend (summary format)
transaction_summaries = [
{
"transaction_id": txn["transactionId"],
"date": txn["transactionDate"],
"description": txn["description"],
"amount": txn["transactionValue"],
"currency": txn["transactionCurrency"],
"status": txn["transactionStatus"],
"account_id": txn["accountId"],
}
for txn in transactions
]
return APIResponse(
success=True,
data=transaction_summaries,
message=f"Retrieved {len(transaction_summaries)} transactions for analytics",
)
except Exception as e:
logger.error(f"Failed to get transactions for analytics: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to get analytics transactions: {str(e)}"
) from e
@router.get("/transactions/stats", response_model=APIResponse) @router.get("/transactions/stats", response_model=APIResponse)
async def get_transaction_stats( async def get_transaction_stats(
days: int = Query(default=30, description="Number of days to include in stats"), days: int = Query(default=30, description="Number of days to include in stats"),

View File

@@ -23,9 +23,7 @@ class Config:
return self._config return self._config
if config_path is None: if config_path is None:
config_path = os.environ.get( config_path = os.environ.get("LEGGEN_CONFIG_FILE")
"LEGGEN_CONFIG_FILE"
)
if not config_path: if not config_path:
config_path = str(path_manager.get_config_file_path()) config_path = str(path_manager.get_config_file_path())
@@ -54,9 +52,7 @@ class Config:
config_data = self._config config_data = self._config
if config_path is None: if config_path is None:
config_path = self._config_path or os.environ.get( config_path = self._config_path or os.environ.get("LEGGEN_CONFIG_FILE")
"LEGGEN_CONFIG_FILE"
)
if not config_path: if not config_path:
config_path = str(path_manager.get_config_file_path()) config_path = str(path_manager.get_config_file_path())

View File

@@ -118,7 +118,7 @@ class DatabaseService:
async def get_transactions_from_db( async def get_transactions_from_db(
self, self,
account_id: Optional[str] = None, account_id: Optional[str] = None,
limit: Optional[int] = 100, limit: Optional[int] = None, # None means no limit, used for stats
offset: Optional[int] = 0, offset: Optional[int] = 0,
date_from: Optional[str] = None, date_from: Optional[str] = None,
date_to: Optional[str] = None, date_to: Optional[str] = None,
@@ -134,7 +134,7 @@ class DatabaseService:
try: try:
transactions = sqlite_db.get_transactions( transactions = sqlite_db.get_transactions(
account_id=account_id, account_id=account_id,
limit=limit or 100, limit=limit, # Pass limit as-is, None means no limit
offset=offset or 0, offset=offset or 0,
date_from=date_from, date_from=date_from,
date_to=date_to, date_to=date_to,
@@ -424,7 +424,7 @@ class DatabaseService:
async def _migrate_null_transaction_ids(self): async def _migrate_null_transaction_ids(self):
"""Populate null internalTransactionId fields using transactionId from raw data""" """Populate null internalTransactionId fields using transactionId from raw data"""
import uuid import uuid
db_path = path_manager.get_database_path() db_path = path_manager.get_database_path()
if not db_path.exists(): if not db_path.exists():
logger.warning("Database file not found, skipping migration") logger.warning("Database file not found, skipping migration")

View File

@@ -32,7 +32,7 @@ class SampleDataGenerator:
"country": "LT", "country": "LT",
}, },
{ {
"id": "BANCOBPI_BBPIPTPL", "id": "BANCOBPI_BBPIPTPL",
"name": "Banco BPI", "name": "Banco BPI",
"bic": "BBPIPTPL", "bic": "BBPIPTPL",
"country": "PT", "country": "PT",
@@ -40,7 +40,7 @@ class SampleDataGenerator:
{ {
"id": "MONZO_MONZGB2L", "id": "MONZO_MONZGB2L",
"name": "Monzo Bank", "name": "Monzo Bank",
"bic": "MONZGB2L", "bic": "MONZGB2L",
"country": "GB", "country": "GB",
}, },
{ {
@@ -50,16 +50,40 @@ class SampleDataGenerator:
"country": "BR", "country": "BR",
}, },
] ]
self.transaction_types = [ self.transaction_types = [
{"description": "Grocery Store", "amount_range": (-150, -20), "frequency": 0.3}, {
"description": "Grocery Store",
"amount_range": (-150, -20),
"frequency": 0.3,
},
{"description": "Coffee Shop", "amount_range": (-15, -3), "frequency": 0.2}, {"description": "Coffee Shop", "amount_range": (-15, -3), "frequency": 0.2},
{"description": "Gas Station", "amount_range": (-80, -30), "frequency": 0.1}, {
{"description": "Online Shopping", "amount_range": (-200, -25), "frequency": 0.15}, "description": "Gas Station",
{"description": "Restaurant", "amount_range": (-60, -15), "frequency": 0.15}, "amount_range": (-80, -30),
"frequency": 0.1,
},
{
"description": "Online Shopping",
"amount_range": (-200, -25),
"frequency": 0.15,
},
{
"description": "Restaurant",
"amount_range": (-60, -15),
"frequency": 0.15,
},
{"description": "Salary", "amount_range": (2500, 5000), "frequency": 0.02}, {"description": "Salary", "amount_range": (2500, 5000), "frequency": 0.02},
{"description": "ATM Withdrawal", "amount_range": (-200, -20), "frequency": 0.05}, {
{"description": "Transfer to Savings", "amount_range": (-1000, -100), "frequency": 0.03}, "description": "ATM Withdrawal",
"amount_range": (-200, -20),
"frequency": 0.05,
},
{
"description": "Transfer to Savings",
"amount_range": (-1000, -100),
"frequency": 0.03,
},
] ]
def ensure_database_dir(self): def ensure_database_dir(self):
@@ -120,15 +144,33 @@ class SampleDataGenerator:
""") """)
# Create indexes # Create indexes
cursor.execute("CREATE INDEX IF NOT EXISTS idx_transactions_internal_id ON transactions(internalTransactionId)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_transactions_date ON transactions(transactionDate)") "CREATE INDEX IF NOT EXISTS idx_transactions_internal_id ON transactions(internalTransactionId)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_transactions_account_date ON transactions(accountId, transactionDate)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_transactions_amount ON transactions(transactionValue)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_balances_account_id ON balances(account_id)") "CREATE INDEX IF NOT EXISTS idx_transactions_date ON transactions(transactionDate)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_balances_timestamp ON balances(timestamp)") )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp ON balances(account_id, type, timestamp)") cursor.execute(
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_institution_id ON accounts(institution_id)") "CREATE INDEX IF NOT EXISTS idx_transactions_account_date ON transactions(accountId, transactionDate)"
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts(status)") )
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_transactions_amount ON transactions(transactionValue)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_balances_account_id ON balances(account_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_balances_timestamp ON balances(timestamp)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_balances_account_type_timestamp ON balances(account_id, type, timestamp)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_accounts_institution_id ON accounts(institution_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts(status)"
)
conn.commit() conn.commit()
conn.close() conn.close()
@@ -141,78 +183,109 @@ class SampleDataGenerator:
"GB": lambda: f"GB{random.randint(10, 99)}MONZ{random.randint(100000, 999999)}{random.randint(100000, 999999)}", "GB": lambda: f"GB{random.randint(10, 99)}MONZ{random.randint(100000, 999999)}{random.randint(100000, 999999)}",
"BR": lambda: f"BR{random.randint(10, 99)}{random.randint(10000000, 99999999)}{random.randint(1000, 9999)}{random.randint(10000000, 99999999)}", "BR": lambda: f"BR{random.randint(10, 99)}{random.randint(10000000, 99999999)}{random.randint(1000, 9999)}{random.randint(10000000, 99999999)}",
} }
return ibans.get(country_code, lambda: f"{country_code}{random.randint(1000000000000000, 9999999999999999)}")() return ibans.get(
country_code,
lambda: f"{country_code}{random.randint(1000000000000000, 9999999999999999)}",
)()
def generate_accounts(self, num_accounts: int = 3) -> List[Dict[str, Any]]: def generate_accounts(self, num_accounts: int = 3) -> List[Dict[str, Any]]:
"""Generate sample accounts.""" """Generate sample accounts."""
accounts = [] accounts = []
base_date = datetime.now() - timedelta(days=90) base_date = datetime.now() - timedelta(days=90)
for i in range(num_accounts): for i in range(num_accounts):
institution = random.choice(self.institutions) institution = random.choice(self.institutions)
account_id = f"account-{i+1:03d}-{random.randint(1000, 9999)}" account_id = f"account-{i + 1:03d}-{random.randint(1000, 9999)}"
account = { account = {
"id": account_id, "id": account_id,
"institution_id": institution["id"], "institution_id": institution["id"],
"status": "READY", "status": "READY",
"iban": self.generate_iban(institution["country"]), "iban": self.generate_iban(institution["country"]),
"name": f"Personal Account {i+1}", "name": f"Personal Account {i + 1}",
"currency": "EUR", "currency": "EUR",
"created": (base_date + timedelta(days=random.randint(0, 30))).isoformat(), "created": (
"last_accessed": (datetime.now() - timedelta(hours=random.randint(1, 48))).isoformat(), base_date + timedelta(days=random.randint(0, 30))
).isoformat(),
"last_accessed": (
datetime.now() - timedelta(hours=random.randint(1, 48))
).isoformat(),
"last_updated": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(),
} }
accounts.append(account) accounts.append(account)
return accounts return accounts
def generate_transactions(self, accounts: List[Dict[str, Any]], num_transactions_per_account: int = 50) -> List[Dict[str, Any]]: def generate_transactions(
self, accounts: List[Dict[str, Any]], num_transactions_per_account: int = 50
) -> List[Dict[str, Any]]:
"""Generate sample transactions for accounts.""" """Generate sample transactions for accounts."""
transactions = [] transactions = []
base_date = datetime.now() - timedelta(days=60) base_date = datetime.now() - timedelta(days=60)
for account in accounts: for account in accounts:
account_transactions = [] account_transactions = []
current_balance = random.uniform(500, 3000) current_balance = random.uniform(500, 3000)
for i in range(num_transactions_per_account): for i in range(num_transactions_per_account):
# Choose transaction type based on frequency weights # Choose transaction type based on frequency weights
transaction_type = random.choices( transaction_type = random.choices(
self.transaction_types, self.transaction_types,
weights=[t["frequency"] for t in self.transaction_types] weights=[t["frequency"] for t in self.transaction_types],
)[0] )[0]
# Generate transaction amount # Generate transaction amount
min_amount, max_amount = transaction_type["amount_range"] min_amount, max_amount = transaction_type["amount_range"]
amount = round(random.uniform(min_amount, max_amount), 2) amount = round(random.uniform(min_amount, max_amount), 2)
# Generate transaction date (more recent transactions are more likely) # Generate transaction date (more recent transactions are more likely)
days_ago = random.choices( days_ago = random.choices(
range(60), range(60), weights=[1.5 ** (60 - d) for d in range(60)]
weights=[1.5 ** (60 - d) for d in range(60)]
)[0] )[0]
transaction_date = base_date + timedelta(days=days_ago, hours=random.randint(6, 22), minutes=random.randint(0, 59)) transaction_date = base_date + timedelta(
days=days_ago,
hours=random.randint(6, 22),
minutes=random.randint(0, 59),
)
# Generate transaction IDs # Generate transaction IDs
transaction_id = f"bank-txn-{account['id']}-{i+1:04d}" transaction_id = f"bank-txn-{account['id']}-{i + 1:04d}"
internal_transaction_id = f"int-txn-{random.randint(100000, 999999)}" internal_transaction_id = f"int-txn-{random.randint(100000, 999999)}"
# Create realistic descriptions # Create realistic descriptions
descriptions = { descriptions = {
"Grocery Store": ["TESCO", "SAINSBURY'S", "LIDL", "ALDI", "WALMART", "CARREFOUR"], "Grocery Store": [
"Coffee Shop": ["STARBUCKS", "COSTA COFFEE", "PRET A MANGER", "LOCAL CAFE"], "TESCO",
"SAINSBURY'S",
"LIDL",
"ALDI",
"WALMART",
"CARREFOUR",
],
"Coffee Shop": [
"STARBUCKS",
"COSTA COFFEE",
"PRET A MANGER",
"LOCAL CAFE",
],
"Gas Station": ["BP", "SHELL", "ESSO", "GALP", "PETROBRAS"], "Gas Station": ["BP", "SHELL", "ESSO", "GALP", "PETROBRAS"],
"Online Shopping": ["AMAZON", "EBAY", "ZALANDO", "ASOS", "APPLE"], "Online Shopping": ["AMAZON", "EBAY", "ZALANDO", "ASOS", "APPLE"],
"Restaurant": ["PIZZA HUT", "MCDONALD'S", "BURGER KING", "LOCAL RESTAURANT"], "Restaurant": [
"PIZZA HUT",
"MCDONALD'S",
"BURGER KING",
"LOCAL RESTAURANT",
],
"Salary": ["MONTHLY SALARY", "PAYROLL DEPOSIT", "SALARY PAYMENT"], "Salary": ["MONTHLY SALARY", "PAYROLL DEPOSIT", "SALARY PAYMENT"],
"ATM Withdrawal": ["ATM WITHDRAWAL", "CASH WITHDRAWAL"], "ATM Withdrawal": ["ATM WITHDRAWAL", "CASH WITHDRAWAL"],
"Transfer to Savings": ["SAVINGS TRANSFER", "INVESTMENT TRANSFER"], "Transfer to Savings": ["SAVINGS TRANSFER", "INVESTMENT TRANSFER"],
} }
specific_descriptions = descriptions.get(transaction_type["description"], [transaction_type["description"]]) specific_descriptions = descriptions.get(
transaction_type["description"], [transaction_type["description"]]
)
description = random.choice(specific_descriptions) description = random.choice(specific_descriptions)
# Create raw transaction (simplified GoCardless format) # Create raw transaction (simplified GoCardless format)
raw_transaction = { raw_transaction = {
"transactionId": transaction_id, "transactionId": transaction_id,
@@ -220,15 +293,17 @@ class SampleDataGenerator:
"valueDate": transaction_date.strftime("%Y-%m-%d"), "valueDate": transaction_date.strftime("%Y-%m-%d"),
"transactionAmount": { "transactionAmount": {
"amount": str(amount), "amount": str(amount),
"currency": account["currency"] "currency": account["currency"],
}, },
"remittanceInformationUnstructured": description, "remittanceInformationUnstructured": description,
"bankTransactionCode": "PMNT" if amount < 0 else "RCDT", "bankTransactionCode": "PMNT" if amount < 0 else "RCDT",
} }
# Determine status (most are booked, some recent ones might be pending) # Determine status (most are booked, some recent ones might be pending)
status = "pending" if days_ago < 2 and random.random() < 0.1 else "booked" status = (
"pending" if days_ago < 2 and random.random() < 0.1 else "booked"
)
transaction = { transaction = {
"accountId": account["id"], "accountId": account["id"],
"transactionId": transaction_id, "transactionId": transaction_id,
@@ -242,31 +317,33 @@ class SampleDataGenerator:
"transactionStatus": status, "transactionStatus": status,
"rawTransaction": raw_transaction, "rawTransaction": raw_transaction,
} }
account_transactions.append(transaction) account_transactions.append(transaction)
current_balance += amount current_balance += amount
# Sort transactions by date for realistic ordering # Sort transactions by date for realistic ordering
account_transactions.sort(key=lambda x: x["transactionDate"]) account_transactions.sort(key=lambda x: x["transactionDate"])
transactions.extend(account_transactions) transactions.extend(account_transactions)
return transactions return transactions
def generate_balances(self, accounts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def generate_balances(self, accounts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Generate sample balances for accounts.""" """Generate sample balances for accounts."""
balances = [] balances = []
for account in accounts: for account in accounts:
# Calculate balance from transactions (simplified) # Calculate balance from transactions (simplified)
base_balance = random.uniform(500, 2000) base_balance = random.uniform(500, 2000)
balance_types = ["interimAvailable", "closingBooked", "authorised"] balance_types = ["interimAvailable", "closingBooked", "authorised"]
for balance_type in balance_types: for balance_type in balance_types:
# Add some variation to balance types # Add some variation to balance types
variation = random.uniform(-50, 50) if balance_type != "interimAvailable" else 0 variation = (
random.uniform(-50, 50) if balance_type != "interimAvailable" else 0
)
balance_amount = base_balance + variation balance_amount = base_balance + variation
balance = { balance = {
"account_id": account["id"], "account_id": account["id"],
"bank": account["institution_id"], "bank": account["institution_id"],
@@ -278,75 +355,113 @@ class SampleDataGenerator:
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} }
balances.append(balance) balances.append(balance)
return balances return balances
def insert_data(self, accounts: List[Dict[str, Any]], transactions: List[Dict[str, Any]], balances: List[Dict[str, Any]]): def insert_data(
self,
accounts: List[Dict[str, Any]],
transactions: List[Dict[str, Any]],
balances: List[Dict[str, Any]],
):
"""Insert generated data into the database.""" """Insert generated data into the database."""
conn = sqlite3.connect(str(self.db_path)) conn = sqlite3.connect(str(self.db_path))
cursor = conn.cursor() cursor = conn.cursor()
# Insert accounts # Insert accounts
for account in accounts: for account in accounts:
cursor.execute(""" cursor.execute(
"""
INSERT OR REPLACE INTO accounts INSERT OR REPLACE INTO accounts
(id, institution_id, status, iban, name, currency, created, last_accessed, last_updated) (id, institution_id, status, iban, name, currency, created, last_accessed, last_updated)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
account["id"], account["institution_id"], account["status"], account["iban"], (
account["name"], account["currency"], account["created"], account["id"],
account["last_accessed"], account["last_updated"] account["institution_id"],
)) account["status"],
account["iban"],
account["name"],
account["currency"],
account["created"],
account["last_accessed"],
account["last_updated"],
),
)
# Insert transactions # Insert transactions
for transaction in transactions: for transaction in transactions:
cursor.execute(""" cursor.execute(
"""
INSERT OR REPLACE INTO transactions INSERT OR REPLACE INTO transactions
(accountId, transactionId, internalTransactionId, institutionId, iban, (accountId, transactionId, internalTransactionId, institutionId, iban,
transactionDate, description, transactionValue, transactionCurrency, transactionDate, description, transactionValue, transactionCurrency,
transactionStatus, rawTransaction) transactionStatus, rawTransaction)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
transaction["accountId"], transaction["transactionId"], (
transaction["internalTransactionId"], transaction["institutionId"], transaction["accountId"],
transaction["iban"], transaction["transactionDate"], transaction["description"], transaction["transactionId"],
transaction["transactionValue"], transaction["transactionCurrency"], transaction["internalTransactionId"],
transaction["transactionStatus"], json.dumps(transaction["rawTransaction"]) transaction["institutionId"],
)) transaction["iban"],
transaction["transactionDate"],
transaction["description"],
transaction["transactionValue"],
transaction["transactionCurrency"],
transaction["transactionStatus"],
json.dumps(transaction["rawTransaction"]),
),
)
# Insert balances # Insert balances
for balance in balances: for balance in balances:
cursor.execute(""" cursor.execute(
"""
INSERT INTO balances INSERT INTO balances
(account_id, bank, status, iban, amount, currency, type, timestamp) (account_id, bank, status, iban, amount, currency, type, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", ( """,
balance["account_id"], balance["bank"], balance["status"], balance["iban"], (
balance["amount"], balance["currency"], balance["type"], balance["timestamp"] balance["account_id"],
)) balance["bank"],
balance["status"],
balance["iban"],
balance["amount"],
balance["currency"],
balance["type"],
balance["timestamp"],
),
)
conn.commit() conn.commit()
conn.close() conn.close()
def generate_sample_database(self, num_accounts: int = 3, num_transactions_per_account: int = 50): def generate_sample_database(
self, num_accounts: int = 3, num_transactions_per_account: int = 50
):
"""Generate complete sample database.""" """Generate complete sample database."""
click.echo(f"🗄️ Creating sample database at: {self.db_path}") click.echo(f"🗄️ Creating sample database at: {self.db_path}")
self.ensure_database_dir() self.ensure_database_dir()
self.create_tables() self.create_tables()
click.echo(f"👥 Generating {num_accounts} sample accounts...") click.echo(f"👥 Generating {num_accounts} sample accounts...")
accounts = self.generate_accounts(num_accounts) accounts = self.generate_accounts(num_accounts)
click.echo(f"💳 Generating {num_transactions_per_account} transactions per account...") click.echo(
transactions = self.generate_transactions(accounts, num_transactions_per_account) f"💳 Generating {num_transactions_per_account} transactions per account..."
)
transactions = self.generate_transactions(
accounts, num_transactions_per_account
)
click.echo("💰 Generating account balances...") click.echo("💰 Generating account balances...")
balances = self.generate_balances(accounts) balances = self.generate_balances(accounts)
click.echo("💾 Inserting data into database...") click.echo("💾 Inserting data into database...")
self.insert_data(accounts, transactions, balances) self.insert_data(accounts, transactions, balances)
# Print summary # Print summary
click.echo("\n✅ Sample database created successfully!") click.echo("\n✅ Sample database created successfully!")
click.echo(f"📊 Summary:") click.echo(f"📊 Summary:")
@@ -354,11 +469,15 @@ class SampleDataGenerator:
click.echo(f" - Transactions: {len(transactions)}") click.echo(f" - Transactions: {len(transactions)}")
click.echo(f" - Balances: {len(balances)}") click.echo(f" - Balances: {len(balances)}")
click.echo(f" - Database: {self.db_path}") click.echo(f" - Database: {self.db_path}")
# Show account details # Show account details
click.echo(f"\n📋 Sample accounts:") click.echo(f"\n📋 Sample accounts:")
for account in accounts: for account in accounts:
institution_name = next(inst["name"] for inst in self.institutions if inst["id"] == account["institution_id"]) institution_name = next(
inst["name"]
for inst in self.institutions
if inst["id"] == account["institution_id"]
)
click.echo(f" - {account['id']} ({institution_name}) - {account['iban']}") click.echo(f" - {account['id']} ({institution_name}) - {account['iban']}")
@@ -387,31 +506,32 @@ class SampleDataGenerator:
) )
def main(database: Path, accounts: int, transactions: int, force: bool): def main(database: Path, accounts: int, transactions: int, force: bool):
"""Generate a sample database with realistic financial data for testing Leggen.""" """Generate a sample database with realistic financial data for testing Leggen."""
# Determine database path # Determine database path
if database: if database:
db_path = database db_path = database
else: else:
# Use development database by default to avoid overwriting production data # Use development database by default to avoid overwriting production data
import os import os
env_path = os.environ.get("LEGGEN_DATABASE_PATH") env_path = os.environ.get("LEGGEN_DATABASE_PATH")
if env_path: if env_path:
db_path = Path(env_path) db_path = Path(env_path)
else: else:
# Default to development database in config directory # Default to development database in config directory
db_path = path_manager.get_config_dir() / "leggen-dev.db" db_path = path_manager.get_config_dir() / "leggen-dev.db"
# Check if database exists and ask for confirmation # Check if database exists and ask for confirmation
if db_path.exists() and not force: if db_path.exists() and not force:
click.echo(f"⚠️ Database already exists: {db_path}") click.echo(f"⚠️ Database already exists: {db_path}")
if not click.confirm("Do you want to overwrite it?"): if not click.confirm("Do you want to overwrite it?"):
click.echo("Aborted.") click.echo("Aborted.")
return return
# Generate the sample database # Generate the sample database
generator = SampleDataGenerator(db_path) generator = SampleDataGenerator(db_path)
generator.generate_sample_database(accounts, transactions) generator.generate_sample_database(accounts, transactions)
# Show usage instructions # Show usage instructions
click.echo(f"\n🚀 Usage instructions:") click.echo(f"\n🚀 Usage instructions:")
click.echo(f"To use this sample database with leggen commands:") click.echo(f"To use this sample database with leggen commands:")
@@ -423,4 +543,4 @@ def main(database: Path, accounts: int, transactions: int, force: bool):
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -87,11 +87,11 @@ def api_client(fastapi_app):
def mock_db_path(temp_db_path): def mock_db_path(temp_db_path):
"""Mock the database path to use temporary database for testing.""" """Mock the database path to use temporary database for testing."""
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager
# Set the path manager to use the temporary database # Set the path manager to use the temporary database
original_database_path = path_manager._database_path original_database_path = path_manager._database_path
path_manager.set_database_path(temp_db_path) path_manager.set_database_path(temp_db_path)
try: try:
yield temp_db_path yield temp_db_path
finally: finally:

View File

@@ -0,0 +1,118 @@
"""Tests for analytics fixes to ensure all transactions are used in statistics."""
import pytest
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock
from fastapi.testclient import TestClient
from leggend.main import create_app
from leggend.services.database_service import DatabaseService
class TestAnalyticsFix:
"""Test analytics fixes for transaction limits"""
@pytest.fixture
def client(self):
app = create_app()
return TestClient(app)
@pytest.fixture
def mock_database_service(self):
return Mock(spec=DatabaseService)
@pytest.mark.asyncio
async def test_transaction_stats_uses_all_transactions(self, client, mock_database_service):
"""Test that transaction stats endpoint uses all transactions (not limited to 100)"""
# Mock data for 600 transactions (simulating the issue)
mock_transactions = []
for i in range(600):
mock_transactions.append({
"transactionId": f"txn-{i}",
"transactionDate": (datetime.now() - timedelta(days=i % 365)).isoformat(),
"description": f"Transaction {i}",
"transactionValue": 10.0 if i % 2 == 0 else -5.0,
"transactionCurrency": "EUR",
"transactionStatus": "booked",
"accountId": f"account-{i % 3}",
})
mock_database_service.get_transactions_from_db = AsyncMock(return_value=mock_transactions)
# Test that the endpoint calls get_transactions_from_db with limit=None
with client as test_client:
# Replace the database service in the route handler
from leggend.api.routes import transactions
original_service = transactions.database_service
transactions.database_service = mock_database_service
try:
response = test_client.get("/api/v1/transactions/stats?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, "Stats endpoint should pass limit=None to get all transactions"
# Verify that the response contains stats for all 600 transactions
assert data["success"] is True
stats = data["data"]
assert stats["total_transactions"] == 600, "Should process all 600 transactions, not just 100"
# Verify calculations are correct for all transactions
expected_income = sum(txn["transactionValue"] for txn in mock_transactions if txn["transactionValue"] > 0)
expected_expenses = sum(abs(txn["transactionValue"]) for txn in mock_transactions if txn["transactionValue"] < 0)
assert stats["total_income"] == expected_income
assert stats["total_expenses"] == expected_expenses
finally:
# Restore original service
transactions.database_service = original_service
@pytest.mark.asyncio
async def test_analytics_endpoint_returns_all_transactions(self, client, mock_database_service):
"""Test that the new analytics endpoint returns all transactions without pagination"""
# Mock data for 600 transactions
mock_transactions = []
for i in range(600):
mock_transactions.append({
"transactionId": f"txn-{i}",
"transactionDate": (datetime.now() - timedelta(days=i % 365)).isoformat(),
"description": f"Transaction {i}",
"transactionValue": 10.0 if i % 2 == 0 else -5.0,
"transactionCurrency": "EUR",
"transactionStatus": "booked",
"accountId": f"account-{i % 3}",
})
mock_database_service.get_transactions_from_db = AsyncMock(return_value=mock_transactions)
with client as test_client:
# Replace the database service in the route handler
from leggend.api.routes import transactions
original_service = transactions.database_service
transactions.database_service = mock_database_service
try:
response = test_client.get("/api/v1/transactions/analytics?days=365")
assert response.status_code == 200
data = response.json()
# Verify that limit=None was passed to get all transactions
mock_database_service.get_transactions_from_db.assert_called_once()
call_args = mock_database_service.get_transactions_from_db.call_args
assert call_args.kwargs.get("limit") is None, "Analytics endpoint should pass limit=None"
# Verify that all 600 transactions are returned
assert data["success"] is True
transactions_data = data["data"]
assert len(transactions_data) == 600, "Analytics endpoint should return all 600 transactions"
finally:
# Restore original service
transactions.database_service = original_service

View File

@@ -12,6 +12,7 @@ from leggen.database.sqlite import persist_balances, get_balances
class MockContext: class MockContext:
"""Mock context for testing.""" """Mock context for testing."""
pass pass
@@ -24,15 +25,15 @@ class TestConfigurablePaths:
# Reset path manager # Reset path manager
original_config = path_manager._config_dir original_config = path_manager._config_dir
original_db = path_manager._database_path original_db = path_manager._database_path
try: try:
path_manager._config_dir = None path_manager._config_dir = None
path_manager._database_path = None path_manager._database_path = None
# Test defaults # Test defaults
config_dir = path_manager.get_config_dir() config_dir = path_manager.get_config_dir()
db_path = path_manager.get_database_path() db_path = path_manager.get_database_path()
assert config_dir == Path.home() / ".config" / "leggen" assert config_dir == Path.home() / ".config" / "leggen"
assert db_path == Path.home() / ".config" / "leggen" / "leggen.db" assert db_path == Path.home() / ".config" / "leggen" / "leggen.db"
finally: finally:
@@ -44,22 +45,25 @@ class TestConfigurablePaths:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
test_config_dir = Path(tmpdir) / "test-config" test_config_dir = Path(tmpdir) / "test-config"
test_db_path = Path(tmpdir) / "test.db" test_db_path = Path(tmpdir) / "test.db"
with patch.dict(os.environ, { with patch.dict(
'LEGGEN_CONFIG_DIR': str(test_config_dir), os.environ,
'LEGGEN_DATABASE_PATH': str(test_db_path) {
}): "LEGGEN_CONFIG_DIR": str(test_config_dir),
"LEGGEN_DATABASE_PATH": str(test_db_path),
},
):
# Reset path manager to pick up environment variables # Reset path manager to pick up environment variables
original_config = path_manager._config_dir original_config = path_manager._config_dir
original_db = path_manager._database_path original_db = path_manager._database_path
try: try:
path_manager._config_dir = None path_manager._config_dir = None
path_manager._database_path = None path_manager._database_path = None
config_dir = path_manager.get_config_dir() config_dir = path_manager.get_config_dir()
db_path = path_manager.get_database_path() db_path = path_manager.get_database_path()
assert config_dir == test_config_dir assert config_dir == test_config_dir
assert db_path == test_db_path assert db_path == test_db_path
finally: finally:
@@ -71,20 +75,25 @@ class TestConfigurablePaths:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
test_config_dir = Path(tmpdir) / "explicit-config" test_config_dir = Path(tmpdir) / "explicit-config"
test_db_path = Path(tmpdir) / "explicit.db" test_db_path = Path(tmpdir) / "explicit.db"
# Save original paths # Save original paths
original_config = path_manager._config_dir original_config = path_manager._config_dir
original_db = path_manager._database_path original_db = path_manager._database_path
try: try:
# Set explicit paths # Set explicit paths
path_manager.set_config_dir(test_config_dir) path_manager.set_config_dir(test_config_dir)
path_manager.set_database_path(test_db_path) path_manager.set_database_path(test_db_path)
assert path_manager.get_config_dir() == test_config_dir assert path_manager.get_config_dir() == test_config_dir
assert path_manager.get_database_path() == test_db_path assert path_manager.get_database_path() == test_db_path
assert path_manager.get_config_file_path() == test_config_dir / "config.toml" assert (
assert path_manager.get_auth_file_path() == test_config_dir / "auth.json" path_manager.get_config_file_path()
== test_config_dir / "config.toml"
)
assert (
path_manager.get_auth_file_path() == test_config_dir / "auth.json"
)
finally: finally:
# Restore original paths # Restore original paths
path_manager._config_dir = original_config path_manager._config_dir = original_config
@@ -94,14 +103,14 @@ class TestConfigurablePaths:
"""Test that database operations work with custom paths.""" """Test that database operations work with custom paths."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
test_db_path = Path(tmp_file.name) test_db_path = Path(tmp_file.name)
# Save original database path # Save original database path
original_db = path_manager._database_path original_db = path_manager._database_path
try: try:
# Set custom database path # Set custom database path
path_manager.set_database_path(test_db_path) path_manager.set_database_path(test_db_path)
# Test database operations # Test database operations
ctx = MockContext() ctx = MockContext()
balance = { balance = {
@@ -114,20 +123,20 @@ class TestConfigurablePaths:
"type": "available", "type": "available",
"timestamp": "2023-01-01T00:00:00", "timestamp": "2023-01-01T00:00:00",
} }
# Persist balance # Persist balance
persist_balances(ctx, balance) persist_balances(ctx, balance)
# Retrieve balances # Retrieve balances
balances = get_balances() balances = get_balances()
assert len(balances) == 1 assert len(balances) == 1
assert balances[0]["account_id"] == "test-account" assert balances[0]["account_id"] == "test-account"
assert balances[0]["amount"] == 1000.0 assert balances[0]["amount"] == 1000.0
# Verify database file exists at custom location # Verify database file exists at custom location
assert test_db_path.exists() assert test_db_path.exists()
finally: finally:
# Restore original path and cleanup # Restore original path and cleanup
path_manager._database_path = original_db path_manager._database_path = original_db
@@ -139,24 +148,24 @@ class TestConfigurablePaths:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
test_config_dir = Path(tmpdir) / "new" / "config" / "dir" test_config_dir = Path(tmpdir) / "new" / "config" / "dir"
test_db_path = Path(tmpdir) / "new" / "db" / "dir" / "test.db" test_db_path = Path(tmpdir) / "new" / "db" / "dir" / "test.db"
# Save original paths # Save original paths
original_config = path_manager._config_dir original_config = path_manager._config_dir
original_db = path_manager._database_path original_db = path_manager._database_path
try: try:
# Set paths to non-existent directories # Set paths to non-existent directories
path_manager.set_config_dir(test_config_dir) path_manager.set_config_dir(test_config_dir)
path_manager.set_database_path(test_db_path) path_manager.set_database_path(test_db_path)
# Ensure directories are created # Ensure directories are created
path_manager.ensure_config_dir_exists() path_manager.ensure_config_dir_exists()
path_manager.ensure_database_dir_exists() path_manager.ensure_database_dir_exists()
assert test_config_dir.exists() assert test_config_dir.exists()
assert test_db_path.parent.exists() assert test_db_path.parent.exists()
finally: finally:
# Restore original paths # Restore original paths
path_manager._config_dir = original_config path_manager._config_dir = original_config
path_manager._database_path = original_db path_manager._database_path = original_db

View File

@@ -23,11 +23,11 @@ def temp_db_path():
def mock_home_db_path(temp_db_path): def mock_home_db_path(temp_db_path):
"""Mock the database path to use temp file.""" """Mock the database path to use temp file."""
from leggen.utils.paths import path_manager from leggen.utils.paths import path_manager
# Set the path manager to use the temporary database # Set the path manager to use the temporary database
original_database_path = path_manager._database_path original_database_path = path_manager._database_path
path_manager.set_database_path(temp_db_path) path_manager.set_database_path(temp_db_path)
try: try:
yield temp_db_path yield temp_db_path
finally: finally: