Inspired by zeroclaw's lightweight patterns for slow hardware: - Response cache (SQLite + SHA-256 keyed) to skip redundant LLM calls - History compaction — LLM-summarize old messages when history exceeds 50 - Query classifier routes simple/research queries to cheaper models - Credential scrubbing removes secrets from tool output before sending to LLM - Cost tracker with daily/monthly budget enforcement (SQLite) - Resilient provider with retry + exponential backoff + fallback provider - Approval engine gains session "always allow" and audit log Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
174 lines
5.9 KiB
Python
174 lines
5.9 KiB
Python
"""Cost tracking — monitor token usage and enforce budget limits."""
|
|
|
|
from __future__ import annotations

import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path

import aiosqlite
from loguru import logger
|
|
|
|
|
|
# Per-model list prices in USD per one million tokens, stored as
# (input_price, output_price). Unknown models fall back to "_default".
# NOTE(review): these figures are hard-coded; confirm them against the
# providers' current price lists when models are added or updated.
MODEL_PRICING: dict[str, tuple[float, float]] = {
    # --- Anthropic ----------------------------------------------------
    "claude-sonnet-4-5-20250929": (3.0, 15.0),
    "claude-haiku-4-5-20251001": (0.80, 4.0),
    "claude-opus-4-6": (15.0, 75.0),
    # --- NVIDIA NIM (via LiteLLM model names) -------------------------
    "nvidia_nim/deepseek-ai/deepseek-v3.1": (0.27, 1.10),
    "nvidia_nim/moonshotai/kimi-k2.5": (0.35, 1.40),
    "nvidia_nim/minimaxai/minimax-m2.1": (0.30, 1.10),
    # --- Catch-all for models not listed above ------------------------
    "_default": (1.0, 3.0),
}
|
|
|
|
|
|
def _get_pricing(model: str) -> tuple[float, float]:
    """Look up the per-million-token prices for *model*.

    Returns an ``(input_price, output_price)`` tuple in USD. Models that
    have no entry in ``MODEL_PRICING`` are priced with the ``"_default"`` entry.
    """
    try:
        return MODEL_PRICING[model]
    except KeyError:
        return MODEL_PRICING["_default"]
|
|
|
|
|
|
@dataclass
class CostRecord:
    """A single priced LLM call, as built by ``CostTracker.record``."""

    model: str  # model identifier the usage was priced under
    input_tokens: int  # prompt/input tokens counted for the call
    output_tokens: int  # completion/output tokens counted for the call
    cost_usd: float  # computed cost of the call in US dollars
    timestamp: float  # Unix time in seconds (time.time()) when recorded
|
|
|
|
|
|
@dataclass
class BudgetConfig:
    """Spending limits consulted by ``CostTracker.check_budget``.

    A limit of 0 (or less) turns that check off. ``warning_threshold`` is the
    fraction of a limit at which the status becomes a warning rather than ok.
    """

    daily_limit_usd: float = 0.0  # USD per day; 0 means unlimited
    monthly_limit_usd: float = 0.0  # USD per month; 0 means unlimited
    warning_threshold: float = 0.8  # e.g. 0.8 -> warn once 80% is spent
|
|
|
|
|
|
class CostTracker:
    """Track LLM costs in SQLite and enforce budget limits.

    Each priced call is appended to a ``cost_records`` table. Budget checks sum
    ``cost_usd`` over the current UTC day and the current UTC calendar month.
    Call :meth:`setup` before use and :meth:`close` when finished. Until
    ``setup`` has run, :meth:`record` still prices calls but does not persist
    them, and :meth:`check_budget` always reports ``"ok"``.
    """

    def __init__(
        self,
        db_path: str | Path = "data/costs.db",
        budget: BudgetConfig | None = None,
    ) -> None:
        self._db_path = str(db_path)
        self._budget = budget or BudgetConfig()
        self._db: aiosqlite.Connection | None = None

    async def setup(self) -> None:
        """Open the database (creating its parent directory) and ensure the schema.

        Calling this again while a connection is already open does nothing, so
        repeated calls do not leak connections. If schema creation fails, the new
        connection is closed and the error is re-raised.
        """
        if self._db is not None:
            return
        Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
        db = await aiosqlite.connect(self._db_path)
        try:
            await db.execute("PRAGMA journal_mode=WAL")
            await db.execute(
                """
                CREATE TABLE IF NOT EXISTS cost_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model TEXT NOT NULL,
                    input_tokens INTEGER NOT NULL,
                    output_tokens INTEGER NOT NULL,
                    cost_usd REAL NOT NULL,
                    timestamp REAL NOT NULL
                )
                """
            )
            # Every budget/summary query filters on a timestamp range.
            await db.execute(
                "CREATE INDEX IF NOT EXISTS idx_cost_records_timestamp "
                "ON cost_records (timestamp)"
            )
            await db.commit()
        except BaseException:
            # Never leave a half-initialised connection open or published.
            await db.close()
            raise
        self._db = db

    async def close(self) -> None:
        """Close the database connection, if one is open."""
        if self._db:
            await self._db.close()
            self._db = None

    async def record(self, model: str, usage: dict[str, int]) -> CostRecord:
        """Record token usage and return the cost record.

        Args:
            model: Model identifier, priced via ``MODEL_PRICING`` (unknown
                models use the ``"_default"`` entry).
            usage: Token counts under ``"input_tokens"`` and ``"output_tokens"``.
                A missing or ``None`` count is treated as 0.

        Returns:
            The :class:`CostRecord` for this call. It is written to the database
            only if :meth:`setup` has been called.
        """
        # `or 0` also covers providers that report a count explicitly as None.
        input_tokens = usage.get("input_tokens") or 0
        output_tokens = usage.get("output_tokens") or 0
        input_price, output_price = _get_pricing(model)
        # Prices are per one million tokens.
        cost = (input_tokens * input_price + output_tokens * output_price) / 1_000_000

        record = CostRecord(
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_usd=cost,
            timestamp=time.time(),
        )

        if self._db:
            await self._db.execute(
                "INSERT INTO cost_records (model, input_tokens, output_tokens, cost_usd, timestamp) VALUES (?, ?, ?, ?, ?)",
                (model, input_tokens, output_tokens, cost, record.timestamp),
            )
            await self._db.commit()

        logger.debug(
            f"Cost: {model} — {input_tokens}in/{output_tokens}out = ${cost:.6f}"
        )
        return record

    async def check_budget(self) -> str:
        """Check budget status. Returns 'ok', 'warning', or 'exceeded'.

        The daily limit (spend since the current UTC midnight) and the monthly
        limit (spend since 00:00 UTC on the first of the current month) are
        both evaluated, and the most severe result wins. For example, a monthly
        overrun is reported as 'exceeded' even when today's spend is only at the
        warning level. A limit of 0 (or less) disables that check.
        """
        if not self._db:
            return "ok"

        now = time.time()
        windows = (
            ("Daily", self._budget.daily_limit_usd, self._day_start(now)),
            ("Monthly", self._budget.monthly_limit_usd, self._month_start(now)),
        )

        status = "ok"
        for label, limit, since in windows:
            if limit > 0:
                total = await self._sum_costs_since(since)
                level = self._classify(total, limit, self._budget.warning_threshold)
                if level == "exceeded":
                    logger.warning(f"{label} budget EXCEEDED: ${total:.4f} / ${limit:.2f}")
                    return "exceeded"
                if level == "warning":
                    logger.warning(f"{label} budget warning: ${total:.4f} / ${limit:.2f}")
                    # Don't stop here: a later window may already be exceeded.
                    status = "warning"
        return status

    async def get_summary(self, since: float | None = None) -> dict:
        """Get a cost summary with per-model breakdown.

        Args:
            since: Unix timestamp; only records at or after it are counted.
                Defaults to the start of the current UTC day (today's spend).

        Returns:
            ``{"total_usd": float, "models": {model: {"input_tokens": ...,
            "output_tokens": ..., "cost_usd": ...}}}``.
        """
        if not self._db:
            return {"total_usd": 0.0, "models": {}}

        if since is None:
            since = self._day_start(time.time())

        async with self._db.execute(
            "SELECT model, SUM(input_tokens), SUM(output_tokens), SUM(cost_usd) "
            "FROM cost_records WHERE timestamp >= ? GROUP BY model",
            (since,),
        ) as cursor:
            rows = await cursor.fetchall()

        models = {
            model: {"input_tokens": inp, "output_tokens": out, "cost_usd": cost}
            for model, inp, out, cost in rows
        }
        total = float(sum(entry["cost_usd"] for entry in models.values()))
        return {"total_usd": total, "models": models}

    async def _sum_costs_since(self, since: float) -> float:
        """Return total ``cost_usd`` of records with ``timestamp >= since`` (0.0 if none)."""
        if not self._db:
            return 0.0
        async with self._db.execute(
            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE timestamp >= ?",
            (since,),
        ) as cursor:
            row = await cursor.fetchone()
        # COALESCE(..., 0) yields an integer 0 for an empty range; normalise to float.
        return float(row[0]) if row else 0.0

    @staticmethod
    def _day_start(now: float) -> float:
        """Return the Unix timestamp of the latest UTC midnight at or before *now*."""
        # Unix time counts exactly 86400 s per day from a UTC-midnight epoch.
        return now - (now % 86400)

    @staticmethod
    def _month_start(now: float) -> float:
        """Return the Unix timestamp of 00:00 UTC on the first day of *now*'s month."""
        current = datetime.fromtimestamp(now, tz=timezone.utc)
        first = current.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        return first.timestamp()

    @staticmethod
    def _classify(total: float, limit: float, threshold: float) -> str:
        """Return 'exceeded', 'warning', or 'ok' for spend *total* against *limit*."""
        if total >= limit:
            return "exceeded"
        if total >= limit * threshold:
            return "warning"
        return "ok"
|