"""Cost tracking — monitor token usage and enforce budget limits.""" from __future__ import annotations import time from dataclasses import dataclass, field from pathlib import Path import aiosqlite from loguru import logger # Pricing per 1M tokens (input, output) in USD MODEL_PRICING: dict[str, tuple[float, float]] = { # Anthropic "claude-sonnet-4-5-20250929": (3.0, 15.0), "claude-haiku-4-5-20251001": (0.80, 4.0), "claude-opus-4-6": (15.0, 75.0), # NVIDIA NIM / LiteLLM "nvidia_nim/deepseek-ai/deepseek-v3.1": (0.27, 1.10), "nvidia_nim/moonshotai/kimi-k2.5": (0.35, 1.40), "nvidia_nim/minimaxai/minimax-m2.1": (0.30, 1.10), # Fallback "_default": (1.0, 3.0), } def _get_pricing(model: str) -> tuple[float, float]: """Get (input_price, output_price) per 1M tokens for a model.""" return MODEL_PRICING.get(model, MODEL_PRICING["_default"]) @dataclass class CostRecord: model: str input_tokens: int output_tokens: int cost_usd: float timestamp: float @dataclass class BudgetConfig: daily_limit_usd: float = 0.0 # 0 = no limit monthly_limit_usd: float = 0.0 # 0 = no limit warning_threshold: float = 0.8 # warn at 80% of budget class CostTracker: """Track LLM costs in SQLite and enforce budget limits.""" def __init__( self, db_path: str | Path = "data/costs.db", budget: BudgetConfig | None = None, ) -> None: self._db_path = str(db_path) self._budget = budget or BudgetConfig() self._db: aiosqlite.Connection | None = None async def setup(self) -> None: Path(self._db_path).parent.mkdir(parents=True, exist_ok=True) self._db = await aiosqlite.connect(self._db_path) await self._db.execute("PRAGMA journal_mode=WAL") await self._db.execute( """ CREATE TABLE IF NOT EXISTS cost_records ( id INTEGER PRIMARY KEY AUTOINCREMENT, model TEXT NOT NULL, input_tokens INTEGER NOT NULL, output_tokens INTEGER NOT NULL, cost_usd REAL NOT NULL, timestamp REAL NOT NULL ) """ ) await self._db.commit() async def close(self) -> None: if self._db: await self._db.close() self._db = None async def record(self, model: str, usage: dict[str, int]) -> CostRecord: """Record token usage and return the cost record.""" input_tokens = usage.get("input_tokens", 0) output_tokens = usage.get("output_tokens", 0) input_price, output_price = _get_pricing(model) cost = (input_tokens * input_price + output_tokens * output_price) / 1_000_000 record = CostRecord( model=model, input_tokens=input_tokens, output_tokens=output_tokens, cost_usd=cost, timestamp=time.time(), ) if self._db: await self._db.execute( "INSERT INTO cost_records (model, input_tokens, output_tokens, cost_usd, timestamp) VALUES (?, ?, ?, ?, ?)", (model, input_tokens, output_tokens, cost, record.timestamp), ) await self._db.commit() logger.debug( f"Cost: {model} — {input_tokens}in/{output_tokens}out = ${cost:.6f}" ) return record async def check_budget(self) -> str: """Check budget status. Returns 'ok', 'warning', or 'exceeded'.""" if not self._db: return "ok" now = time.time() # Daily check if self._budget.daily_limit_usd > 0: day_start = now - (now % 86400) daily_total = await self._sum_costs_since(day_start) if daily_total >= self._budget.daily_limit_usd: logger.warning(f"Daily budget EXCEEDED: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}") return "exceeded" if daily_total >= self._budget.daily_limit_usd * self._budget.warning_threshold: logger.warning(f"Daily budget warning: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}") return "warning" # Monthly check if self._budget.monthly_limit_usd > 0: month_start = now - (now % (86400 * 30)) monthly_total = await self._sum_costs_since(month_start) if monthly_total >= self._budget.monthly_limit_usd: logger.warning(f"Monthly budget EXCEEDED: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}") return "exceeded" if monthly_total >= self._budget.monthly_limit_usd * self._budget.warning_threshold: logger.warning(f"Monthly budget warning: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}") return "warning" return "ok" async def get_summary(self) -> dict: """Get a cost summary with per-model breakdown.""" if not self._db: return {"total_usd": 0, "models": {}} now = time.time() day_start = now - (now % 86400) rows: list = [] async with self._db.execute( "SELECT model, SUM(input_tokens), SUM(output_tokens), SUM(cost_usd) FROM cost_records WHERE timestamp >= ? GROUP BY model", (day_start,), ) as cursor: rows = await cursor.fetchall() models = {} total = 0.0 for model, inp, out, cost in rows: models[model] = {"input_tokens": inp, "output_tokens": out, "cost_usd": cost} total += cost return {"total_usd": total, "models": models} async def _sum_costs_since(self, since: float) -> float: if not self._db: return 0.0 async with self._db.execute( "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE timestamp >= ?", (since,), ) as cursor: row = await cursor.fetchone() return row[0] if row else 0.0