Add performance features: caching, cost tracking, retry, compaction, classification, scrubbing

Inspired by zeroclaw's lightweight patterns for slow hardware:
- Response cache (SQLite + SHA-256 keyed) to skip redundant LLM calls
- History compaction — LLM-summarize old messages when history exceeds 50
- Query classifier routes simple/research queries to cheaper models
- Credential scrubbing removes secrets from tool output before sending to LLM
- Cost tracker with daily/monthly budget enforcement (SQLite)
- Resilient provider with retry + exponential backoff + fallback provider
- Approval engine gains session "always allow" and audit log

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Kaloyan Danchev
2026-02-19 09:20:52 +02:00
parent e24e3026b6
commit 872ed24f0c
10 changed files with 694 additions and 17 deletions

View File

@@ -43,3 +43,12 @@ agents:
orchestrator:
max_concurrent: 5
delegation_timeout: 120
performance:
cache_ttl: 3600
daily_budget_usd: 5.0
monthly_budget_usd: 50.0
fallback_model: nvidia_nim/deepseek-ai/deepseek-v3.1
model_routing:
simple: nvidia_nim/deepseek-ai/deepseek-v3.1
research: nvidia_nim/moonshotai/kimi-k2.5

View File

@@ -19,6 +19,7 @@ dependencies = [
"json-repair>=0.30.0",
"duckduckgo-search>=7.0.0",
"pypdf>=5.0.0",
"aiosqlite>=0.20.0",
]
[project.scripts]

100
xtrm_agent/cache.py Normal file
View File

@@ -0,0 +1,100 @@
"""Response cache — avoid redundant LLM calls for identical prompts."""
from __future__ import annotations
import hashlib
import json
import time
from pathlib import Path
import aiosqlite
from loguru import logger
class ResponseCache:
    """SQLite-backed cache of LLM responses with TTL-based expiry.

    Keys are SHA-256 digests of (model, messages); ``get``/``put`` degrade to
    no-ops until ``setup()`` has opened the database.
    """

    def __init__(self, db_path: str | Path = "data/cache.db", ttl: int = 3600) -> None:
        self._db_path = str(db_path)
        self._ttl = ttl
        self._db: aiosqlite.Connection | None = None

    async def setup(self) -> None:
        """Open the database and create the cache table if it doesn't exist."""
        Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
        self._db = await aiosqlite.connect(self._db_path)
        # WAL keeps readers from blocking on the frequent hit-count updates.
        await self._db.execute("PRAGMA journal_mode=WAL")
        await self._db.execute(
            """
            CREATE TABLE IF NOT EXISTS response_cache (
                key TEXT PRIMARY KEY,
                response TEXT NOT NULL,
                model TEXT NOT NULL,
                created_at REAL NOT NULL,
                hits INTEGER DEFAULT 0
            )
            """
        )
        await self._db.commit()

    async def close(self) -> None:
        """Close the database connection (safe to call more than once)."""
        if self._db:
            await self._db.close()
            self._db = None

    @staticmethod
    def _make_key(model: str, messages: list[dict]) -> str:
        """Deterministic SHA-256 digest of the model name plus message payload."""
        payload = json.dumps({"model": model, "messages": messages}, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

    async def get(self, model: str, messages: list[dict]) -> str | None:
        """Look up a cached response. Returns None on miss or expiry."""
        if self._db is None:
            return None
        cache_key = self._make_key(model, messages)
        async with self._db.execute(
            "SELECT response, created_at FROM response_cache WHERE key = ?",
            (cache_key,),
        ) as cursor:
            hit = await cursor.fetchone()
        if hit is None:
            return None
        cached_response, stored_at = hit
        if time.time() - stored_at > self._ttl:
            # Expired — evict eagerly so the row doesn't linger until clear_expired().
            await self._db.execute("DELETE FROM response_cache WHERE key = ?", (cache_key,))
            await self._db.commit()
            return None
        # Bump the hit counter for observability.
        await self._db.execute(
            "UPDATE response_cache SET hits = hits + 1 WHERE key = ?", (cache_key,)
        )
        await self._db.commit()
        logger.debug(f"Cache hit for {model} (key={cache_key[:12]}...)")
        return cached_response

    async def put(self, model: str, messages: list[dict], response: str) -> None:
        """Store a response, replacing any previous entry for the same key."""
        if self._db is None:
            return
        await self._db.execute(
            """
            INSERT OR REPLACE INTO response_cache (key, response, model, created_at, hits)
            VALUES (?, ?, ?, ?, 0)
            """,
            (self._make_key(model, messages), response, model, time.time()),
        )
        await self._db.commit()

    async def clear_expired(self) -> int:
        """Remove expired entries. Returns the number of deleted rows."""
        if self._db is None:
            return 0
        expiry_cutoff = time.time() - self._ttl
        cursor = await self._db.execute(
            "DELETE FROM response_cache WHERE created_at < ?", (expiry_cutoff,)
        )
        await self._db.commit()
        return cursor.rowcount

86
xtrm_agent/classifier.py Normal file
View File

@@ -0,0 +1,86 @@
"""Query classifier — route queries to appropriate models."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
@dataclass
class ClassificationResult:
    """Result of classifying a query for model routing."""
    category: str  # one of: "simple", "code", "research", "complex"
    model_hint: str  # suggested model override, or "" to use the agent's default
# Short, trivial queries — greetings, thanks, yes/no, etc.
# Anchored with ^…$ so they only fire on whole (very short) messages.
_SIMPLE_PATTERNS = [
    re.compile(r"^(hi|hello|hey|thanks|thank you|ok|yes|no|sure|bye|good\s*(morning|evening|night))[\s!.?]*$", re.IGNORECASE),
    re.compile(r"^(what time|what date|what day)[\s\S]{0,20}\??\s*$", re.IGNORECASE),
]
# Code-related queries.
# NOTE(review): the keyword pattern below lacks IGNORECASE — presumably
# intentional (lowercase `def`/`const`/… signal actual code) — confirm.
_CODE_PATTERNS = [
    re.compile(r"\b(function|class|def|import|const|let|var|return|async|await)\b"),
    re.compile(r"\b(refactor|debug|fix|implement|code|compile|build|test|lint|deploy)\b", re.IGNORECASE),
    re.compile(r"```"),  # fenced code block
    re.compile(r"\.(py|js|ts|rs|go|java|cpp|c|rb|php|swift|kt)\b"),  # source-file extension
]
# Research / information queries — a single hit is enough to route to "research".
_RESEARCH_PATTERNS = [
    re.compile(r"\b(search|find|look\s*up|research|summarize|explain|compare|analyze)\b", re.IGNORECASE),
    re.compile(r"\b(what is|how does|why does|can you explain)\b", re.IGNORECASE),
]
class QueryClassifier:
    """Classify queries for model routing.

    Configure with a mapping of category → model override.
    Categories: "simple", "code", "research", "complex".
    """

    def __init__(self, model_map: dict[str, str] | None = None) -> None:
        self._model_map = model_map or {}

    def _route(self, category: str) -> ClassificationResult:
        """Build a result carrying the configured model override for *category*."""
        return ClassificationResult(
            category=category,
            model_hint=self._model_map.get(category, ""),
        )

    def classify(self, query: str) -> ClassificationResult:
        """Classify a query and suggest a model."""
        text = query.strip()
        # Very short queries that match a greeting/ack pattern are "simple".
        if len(text) < 20 and any(p.match(text) for p in _SIMPLE_PATTERNS):
            return self._route("simple")
        # Two or more code signals → "code".
        if sum(p.search(text) is not None for p in _CODE_PATTERNS) >= 2:
            return self._route("code")
        # A single research signal is enough.
        if any(p.search(text) for p in _RESEARCH_PATTERNS):
            return self._route("research")
        # Long or multi-line queries are treated as "complex".
        if len(text) > 500 or text.count("\n") > 5:
            return self._route("complex")
        # Default — no model override.
        return ClassificationResult(category="complex", model_hint="")

View File

@@ -61,6 +61,16 @@ class MCPServerConfig(BaseModel):
url: str = ""
class PerformanceConfig(BaseModel):
    """Performance tuning — caching, cost tracking, model routing."""
    # TTL for cached LLM responses, in seconds.
    cache_ttl: int = 3600
    # Budget caps in USD; 0.0 disables enforcement for that window.
    daily_budget_usd: float = 0.0
    monthly_budget_usd: float = 0.0
    # Model id used by the fallback provider; "" disables the fallback.
    fallback_model: str = ""
    # Maps query category ("simple", "code", "research", "complex") → model id.
    model_routing: dict[str, str] = Field(default_factory=dict)
class OrchestratorConfig(BaseModel):
    """Settings for multi-agent orchestration."""
    # Maximum number of concurrently running delegations.
    max_concurrent: int = 5
    # Delegation timeout (presumably seconds — confirm against the caller).
    delegation_timeout: int = 120
@@ -87,6 +97,7 @@ class Config(BaseModel):
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
agents: dict[str, str] = Field(default_factory=dict)
orchestrator: OrchestratorConfig = Field(default_factory=OrchestratorConfig)
performance: PerformanceConfig = Field(default_factory=PerformanceConfig)
def load_config(path: str | Path = "config.yaml") -> Config:

173
xtrm_agent/cost.py Normal file
View File

@@ -0,0 +1,173 @@
"""Cost tracking — monitor token usage and enforce budget limits."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from pathlib import Path
import aiosqlite
from loguru import logger
# Pricing per 1M tokens (input, output) in USD
MODEL_PRICING: dict[str, tuple[float, float]] = {
# Anthropic
"claude-sonnet-4-5-20250929": (3.0, 15.0),
"claude-haiku-4-5-20251001": (0.80, 4.0),
"claude-opus-4-6": (15.0, 75.0),
# NVIDIA NIM / LiteLLM
"nvidia_nim/deepseek-ai/deepseek-v3.1": (0.27, 1.10),
"nvidia_nim/moonshotai/kimi-k2.5": (0.35, 1.40),
"nvidia_nim/minimaxai/minimax-m2.1": (0.30, 1.10),
# Fallback
"_default": (1.0, 3.0),
}
def _get_pricing(model: str) -> tuple[float, float]:
"""Get (input_price, output_price) per 1M tokens for a model."""
return MODEL_PRICING.get(model, MODEL_PRICING["_default"])
@dataclass
class CostRecord:
    """A single LLM call's token usage and computed cost."""
    model: str  # model id the call was billed against
    input_tokens: int  # prompt tokens consumed
    output_tokens: int  # completion tokens produced
    cost_usd: float  # total cost of this call in USD
    timestamp: float  # Unix epoch seconds when the record was created
@dataclass
class BudgetConfig:
    """Spending limits enforced by CostTracker."""
    daily_limit_usd: float = 0.0  # 0 = no limit
    monthly_limit_usd: float = 0.0  # 0 = no limit
    warning_threshold: float = 0.8  # emit a warning at this fraction of the budget
class CostTracker:
    """Track LLM costs in SQLite and enforce budget limits.

    All methods degrade gracefully before ``setup()`` is called: recording
    still returns a priced CostRecord (unpersisted) and budget checks
    report "ok".
    """

    def __init__(
        self,
        db_path: str | Path = "data/costs.db",
        budget: BudgetConfig | None = None,
    ) -> None:
        self._db_path = str(db_path)
        self._budget = budget or BudgetConfig()
        self._db: aiosqlite.Connection | None = None

    async def setup(self) -> None:
        """Open the database and create the cost table if missing."""
        Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
        self._db = await aiosqlite.connect(self._db_path)
        # WAL keeps concurrent reads cheap while records are appended.
        await self._db.execute("PRAGMA journal_mode=WAL")
        await self._db.execute(
            """
            CREATE TABLE IF NOT EXISTS cost_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model TEXT NOT NULL,
                input_tokens INTEGER NOT NULL,
                output_tokens INTEGER NOT NULL,
                cost_usd REAL NOT NULL,
                timestamp REAL NOT NULL
            )
            """
        )
        await self._db.commit()

    async def close(self) -> None:
        """Close the database connection (idempotent)."""
        if self._db:
            await self._db.close()
            self._db = None

    async def record(self, model: str, usage: dict[str, int]) -> CostRecord:
        """Price and persist one call's token usage; returns the CostRecord.

        ``usage`` is expected to carry "input_tokens" / "output_tokens"
        (missing keys count as 0).
        """
        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)
        input_price, output_price = _get_pricing(model)
        # Pricing table values are per 1M tokens.
        cost = (input_tokens * input_price + output_tokens * output_price) / 1_000_000
        record = CostRecord(
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_usd=cost,
            timestamp=time.time(),
        )
        if self._db:
            await self._db.execute(
                "INSERT INTO cost_records (model, input_tokens, output_tokens, cost_usd, timestamp) VALUES (?, ?, ?, ?, ?)",
                (model, input_tokens, output_tokens, cost, record.timestamp),
            )
            await self._db.commit()
        # BUGFIX: a separator was missing between the model name and the
        # token counts, producing garbled log lines like "model123in/45out".
        logger.debug(
            f"Cost: {model} {input_tokens}in/{output_tokens}out = ${cost:.6f}"
        )
        return record

    async def check_budget(self) -> str:
        """Check budget status. Returns 'ok', 'warning', or 'exceeded'.

        BUGFIX: previously a daily *warning* returned before the monthly
        *exceeded* check could run, so a blown monthly budget was masked.
        Both windows are now evaluated and the worst status wins.
        """
        if not self._db:
            return "ok"
        now = time.time()
        windows: list[tuple[str, float, float]] = []
        if self._budget.daily_limit_usd > 0:
            # UTC-midnight-aligned day bucket.
            windows.append(("Daily", now - (now % 86400), self._budget.daily_limit_usd))
        if self._budget.monthly_limit_usd > 0:
            # NOTE: 30-day epoch-aligned bucket, not a calendar month.
            windows.append(("Monthly", now - (now % (86400 * 30)), self._budget.monthly_limit_usd))
        status = "ok"
        for label, since, limit in windows:
            total = await self._sum_costs_since(since)
            if total >= limit:
                logger.warning(f"{label} budget EXCEEDED: ${total:.4f} / ${limit:.2f}")
                return "exceeded"
            if total >= limit * self._budget.warning_threshold:
                logger.warning(f"{label} budget warning: ${total:.4f} / ${limit:.2f}")
                status = "warning"
        return status

    async def get_summary(self) -> dict:
        """Return today's spend with a per-model breakdown."""
        if not self._db:
            return {"total_usd": 0, "models": {}}
        now = time.time()
        day_start = now - (now % 86400)
        async with self._db.execute(
            "SELECT model, SUM(input_tokens), SUM(output_tokens), SUM(cost_usd) FROM cost_records WHERE timestamp >= ? GROUP BY model",
            (day_start,),
        ) as cursor:
            rows = await cursor.fetchall()
        models = {}
        total = 0.0
        for model, inp, out, cost in rows:
            models[model] = {"input_tokens": inp, "output_tokens": out, "cost_usd": cost}
            total += cost
        return {"total_usd": total, "models": models}

    async def _sum_costs_since(self, since: float) -> float:
        """Sum recorded costs with timestamp >= *since* (0.0 when no DB)."""
        if not self._db:
            return 0.0
        async with self._db.execute(
            "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE timestamp >= ?",
            (since,),
        ) as cursor:
            row = await cursor.fetchone()
            return row[0] if row else 0.0

View File

@@ -2,17 +2,45 @@
from __future__ import annotations

import asyncio
import json
from typing import Any

from loguru import logger

from xtrm_agent.cache import ResponseCache
from xtrm_agent.classifier import ClassificationResult, QueryClassifier
from xtrm_agent.config import AgentFileConfig
from xtrm_agent.cost import CostTracker
from xtrm_agent.llm.provider import LLMProvider, LLMResponse
from xtrm_agent.scrub import scrub_credentials
from xtrm_agent.tools.approval import ApprovalEngine
from xtrm_agent.tools.registry import ToolRegistry
# History compaction settings: when the non-system history grows past
# _MAX_HISTORY_MESSAGES, older messages are LLM-summarized and only the
# last _KEEP_RECENT messages are kept verbatim.
_MAX_HISTORY_MESSAGES = 50
_KEEP_RECENT = 20
_COMPACTION_PROMPT = (
    "Summarize the following conversation history concisely. "
    "Preserve key facts, decisions, tool results, and context needed to continue. "
    "Be brief but complete."
)
# Retry settings: exponential backoff of _RETRY_BACKOFF_BASE ** attempt seconds.
_MAX_RETRIES = 3
_RETRY_BACKOFF_BASE = 2.0
# Substrings (matched case-insensitively against the error text) that indicate
# a permanent failure — retrying would only waste time and budget.
_NON_RETRYABLE_KEYWORDS = [
    "authentication",
    "unauthorized",
    "invalid api key",
    "model not found",
    "permission denied",
    "invalid_api_key",
    "401",
    "403",
    "404",
]
class Engine:
"""Runs one agent's LLM loop: messages → LLM → tool calls → loop → response."""
@@ -23,16 +51,33 @@ class Engine:
provider: LLMProvider,
tools: ToolRegistry,
approval: ApprovalEngine,
cache: ResponseCache | None = None,
cost_tracker: CostTracker | None = None,
classifier: QueryClassifier | None = None,
fallback_provider: LLMProvider | None = None,
) -> None:
self.config = agent_config
self.provider = provider
self.tools = tools
self.approval = approval
self.cache = cache
self.cost_tracker = cost_tracker
self.classifier = classifier
self.fallback_provider = fallback_provider
async def run(self, user_message: str) -> str:
    """Process a single user message through the agent loop.

    When a classifier is configured, the query's category may override the
    model used for this turn.

    BUGFIX: a stray early ``return await self._agent_loop(messages)`` left
    over from the previous version made the classification code unreachable;
    it has been removed.
    """
    messages = self._build_initial_messages(user_message)
    # Query classification — override the model if the classifier suggests one.
    model_override = ""
    if self.classifier:
        result = self.classifier.classify(user_message)
        if result.model_hint:
            model_override = result.model_hint
            logger.debug(f"[{self.config.name}] Classified as '{result.category}' → model={result.model_hint}")
    return await self._agent_loop(messages, model_override=model_override)
async def run_delegation(self, task: str) -> str:
"""Process a delegation task (no system prompt changes)."""
@@ -46,22 +91,45 @@ class Engine:
messages.append({"role": "user", "content": user_message})
return messages
async def _agent_loop(self, messages: list[dict[str, Any]]) -> str:
"""Core agent iteration loop."""
async def _agent_loop(
self,
messages: list[dict[str, Any]],
model_override: str = "",
) -> str:
"""Core agent iteration loop with caching, cost tracking, retry, and compaction."""
for iteration in range(self.config.max_iterations):
model = self.config.model or self.provider.get_default_model()
model = model_override or self.config.model or self.provider.get_default_model()
tool_defs = self.tools.get_definitions() if self.tools.names() else None
response = await self.provider.complete(
# Budget check
if self.cost_tracker:
status = await self.cost_tracker.check_budget()
if status == "exceeded":
return "(budget exceeded — please try again later)"
# Response cache check (only for first iteration — no tool calls yet)
if self.cache and iteration == 0 and not tool_defs:
cached = await self.cache.get(model, messages)
if cached:
return cached
# LLM call with retry + fallback
response = await self._call_with_retry(
messages=messages,
tools=tool_defs,
model=model,
max_tokens=8192,
temperature=self.config.temperature,
)
# Track cost
if self.cost_tracker and response.usage:
await self.cost_tracker.record(model, response.usage)
if not response.has_tool_calls:
return response.content or "(no response)"
content = response.content or "(no response)"
# Cache the final response (only if no tool calls were used in the conversation)
if self.cache and iteration == 0:
await self.cache.put(model, messages, content)
return content
# Add assistant message with tool calls
messages.append(self._assistant_message(response))
@@ -69,6 +137,8 @@ class Engine:
# Execute each tool call
for tc in response.tool_calls:
result = await self._execute_tool(tc.name, tc.arguments)
# Scrub credentials from tool output
result = scrub_credentials(result)
messages.append(
{
"role": "tool",
@@ -83,8 +153,121 @@ class Engine:
f"{len(response.tool_calls)} tool call(s)"
)
# History compaction — summarize old messages if history is too long
messages = await self._compact_history(messages, model)
return "(max iterations reached)"
async def _call_with_retry(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str,
) -> LLMResponse:
    """Call the LLM with bounded retries, then a fallback provider.

    Retries up to _MAX_RETRIES times with exponential backoff
    (_RETRY_BACKOFF_BASE ** attempt seconds). Errors whose text contains a
    _NON_RETRYABLE_KEYWORDS entry skip straight to the fallback. If the
    fallback also fails, a synthetic error LLMResponse is returned rather
    than raising.

    BUGFIX: the backoff sleep no longer runs after the *final* failed
    attempt — there is no retry coming, so it only delayed the fallback.
    """
    last_error: Exception | None = None
    for attempt in range(_MAX_RETRIES):
        try:
            return await self.provider.complete(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=8192,
                temperature=self.config.temperature,
            )
        except Exception as e:
            last_error = e
            error_str = str(e).lower()
            # Don't retry permanent failures (auth errors, unknown model, 4xx).
            if any(kw in error_str for kw in _NON_RETRYABLE_KEYWORDS):
                logger.error(f"[{self.config.name}] Non-retryable error: {e}")
                break
            if attempt + 1 < _MAX_RETRIES:
                wait = _RETRY_BACKOFF_BASE ** attempt
                logger.warning(
                    f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — retrying in {wait}s"
                )
                await asyncio.sleep(wait)
            else:
                logger.warning(
                    f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e}"
                )
    # Try the fallback provider with its own default model.
    if self.fallback_provider:
        logger.info(f"[{self.config.name}] Trying fallback provider")
        try:
            return await self.fallback_provider.complete(
                messages=messages,
                tools=tools,
                model=self.fallback_provider.get_default_model(),
                max_tokens=8192,
                temperature=self.config.temperature,
            )
        except Exception as fallback_err:
            logger.error(f"[{self.config.name}] Fallback also failed: {fallback_err}")
    # All retries and the fallback exhausted — surface the error as content.
    return LLMResponse(
        content=f"(error after {_MAX_RETRIES} retries: {last_error})",
        finish_reason="error",
    )
async def _compact_history(
    self,
    messages: list[dict[str, Any]],
    model: str,
) -> list[dict[str, Any]]:
    """Compact old messages once history exceeds _MAX_HISTORY_MESSAGES.

    The oldest non-system messages are LLM-summarized into a single
    user/assistant exchange; the most recent _KEEP_RECENT messages are kept
    verbatim. On any summarization failure the history is returned unchanged.

    BUGFIX: the kept window can no longer begin with a "tool" message — a
    tool result without its preceding assistant tool-call message is
    rejected by chat APIs. The window is widened backwards until it starts
    on a non-tool message.
    """
    non_system = [m for m in messages if m["role"] != "system"]
    if len(non_system) <= _MAX_HISTORY_MESSAGES:
        return messages
    system_msgs = [m for m in messages if m["role"] == "system"]
    split = len(non_system) - _KEEP_RECENT
    while split > 0 and non_system[split]["role"] == "tool":
        split -= 1
    if split <= 0:
        # Nothing safe to summarize away.
        return messages
    old_msgs = non_system[:split]
    recent_msgs = non_system[split:]
    # Ask the LLM to summarize the old messages.
    summary_messages = [
        {"role": "system", "content": _COMPACTION_PROMPT},
        {
            "role": "user",
            "content": json.dumps(
                [{"role": m["role"], "content": m.get("content", "")} for m in old_msgs],
                indent=None,
            ),
        },
    ]
    try:
        summary_response = await self.provider.complete(
            messages=summary_messages,
            model=model,
            max_tokens=2048,
            temperature=0.1,
        )
        summary_text = summary_response.content or ""
        logger.info(
            f"[{self.config.name}] Compacted {len(old_msgs)} messages → "
            f"{len(summary_text)} char summary"
        )
    except Exception as e:
        # Best-effort: compaction failure must not break the agent loop.
        logger.warning(f"[{self.config.name}] History compaction failed: {e}")
        return messages
    # Rebuild: system messages + summary exchange + recent tail.
    compacted = list(system_msgs)
    compacted.append({
        "role": "user",
        "content": f"[Conversation summary]\n{summary_text}",
    })
    compacted.append({
        "role": "assistant",
        "content": "Understood, I have the context from our previous conversation.",
    })
    compacted.extend(recent_msgs)
    return compacted
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
"""Execute a tool with approval check."""
approved = await self.approval.check(name, arguments)

View File

@@ -10,7 +10,10 @@ from typing import Any
from loguru import logger
from xtrm_agent.bus import AgentMessage, InboundMessage, MessageBus, OutboundMessage
from xtrm_agent.cache import ResponseCache
from xtrm_agent.classifier import QueryClassifier
from xtrm_agent.config import Config, AgentFileConfig, parse_agent_file
from xtrm_agent.cost import BudgetConfig, CostTracker
from xtrm_agent.engine import Engine
from xtrm_agent.llm.anthropic import AnthropicProvider
from xtrm_agent.llm.litellm import LiteLLMProvider
@@ -35,6 +38,8 @@ class Orchestrator:
self._agent_configs: dict[str, AgentFileConfig] = {}
self._mcp_stack = AsyncExitStack()
self._running = False
self._cache: ResponseCache | None = None
self._cost_tracker: CostTracker | None = None
# Channel defaults for routing
channel_defaults = {}
@@ -53,6 +58,27 @@ class Orchestrator:
workspace = Path(self.config.tools.workspace).resolve()
workspace.mkdir(parents=True, exist_ok=True)
# Initialize shared response cache
self._cache = ResponseCache(
db_path=workspace / "cache.db",
ttl=self.config.performance.cache_ttl,
)
await self._cache.setup()
# Initialize cost tracker
budget = BudgetConfig(
daily_limit_usd=self.config.performance.daily_budget_usd,
monthly_limit_usd=self.config.performance.monthly_budget_usd,
)
self._cost_tracker = CostTracker(
db_path=workspace / "costs.db",
budget=budget,
)
await self._cost_tracker.setup()
# Initialize query classifier
classifier = QueryClassifier(model_map=self.config.performance.model_routing)
# Parse all agent definitions
for agent_name, agent_path in self.config.agents.items():
p = Path(agent_path)
@@ -74,6 +100,10 @@ class Orchestrator:
await self._mcp_stack.__aenter__()
await connect_mcp_servers(self.config.mcp_servers, global_registry, self._mcp_stack)
# Create fallback provider (LiteLLM with a cheap model)
fallback_model = self.config.performance.fallback_model
fallback_provider = LiteLLMProvider(model=fallback_model) if fallback_model else None
# Create one engine per agent
agent_names = list(self._agent_configs.keys())
for agent_name, agent_cfg in self._agent_configs.items():
@@ -107,6 +137,10 @@ class Orchestrator:
provider=provider,
tools=agent_registry,
approval=approval,
cache=self._cache,
cost_tracker=self._cost_tracker,
classifier=classifier,
fallback_provider=fallback_provider,
)
self._engines[agent_name] = engine
@@ -190,6 +224,10 @@ class Orchestrator:
async def stop(self) -> None:
    """Shut down: stop the loop, close cache/cost stores, tear down MCP."""
    self._running = False
    # Close the shared SQLite-backed resources, if they were initialized.
    for resource in (self._cache, self._cost_tracker):
        if resource:
            await resource.close()
    await self._mcp_stack.aclose()
    logger.info("Orchestrator stopped")

46
xtrm_agent/scrub.py Normal file
View File

@@ -0,0 +1,46 @@
"""Credential scrubbing — prevent secret leakage in tool output."""
from __future__ import annotations
import re
# Patterns that match common secret formats.  Key=value style patterns capture
# the secret in a group; standalone-token patterns (Bearer, AKIA…, gh*_) do not.
_SECRET_PATTERNS = [
    # Key=value patterns (API keys, tokens, passwords)
    re.compile(
        r"(?i)(api[_-]?key|token|password|passwd|secret|access[_-]?key|private[_-]?key|auth)"
        r"[\s]*[=:]\s*['\"]?([^\s'\"]{8,})['\"]?",
    ),
    # Bearer tokens
    re.compile(r"Bearer\s+[A-Za-z0-9\-._~+/]+=*", re.IGNORECASE),
    # AWS access keys (AKIA...)
    re.compile(r"AKIA[0-9A-Z]{16}"),
    # AWS secret keys (40 char base64)
    re.compile(r"(?i)aws[_-]?secret[_-]?access[_-]?key[\s]*[=:]\s*['\"]?([A-Za-z0-9/+=]{40})['\"]?"),
    # GitHub tokens
    re.compile(r"gh[pousr]_[A-Za-z0-9_]{36,}"),
    # Generic long hex strings that look like secrets (32+ hex chars after key= or token=)
    re.compile(r"(?i)(?:key|token|secret)[=:]\s*['\"]?([0-9a-f]{32,})['\"]?"),
]
# Replacement marker for anything that looks like a secret.
_REDACTED = "[REDACTED]"


def scrub_credentials(text: str) -> str:
    """Scrub potential secrets from *text*, replacing them with [REDACTED].

    Applied to tool output before it is sent to the LLM so that keys, tokens,
    and passwords surfaced by tools never leave the host.
    """
    result = text
    for pattern in _SECRET_PATTERNS:
        result = pattern.sub(_redact_match, result)
    return result


def _redact_match(match: re.Match) -> str:
    """Replace a matched secret, keeping the key name visible when possible.

    BUGFIX (security): standalone token matches (Bearer …, AKIA…, gh*_…) were
    previously split on the first '=' as if they were key=value pairs, so a
    base64 token containing '=' padding leaked everything before the padding.
    Only patterns that capture the secret in a group are key=value style; all
    others are redacted wholesale.
    """
    full = match.group(0)
    if match.lastindex:
        # key=value / key: value — keep the key part, redact the value.
        for sep in ("=", ":"):
            if sep in full:
                key_part = full[: full.index(sep) + 1]
                return f"{key_part} {_REDACTED}"
    return _REDACTED

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import time
from enum import Enum
from typing import Any
@@ -15,7 +16,7 @@ class ApprovalPolicy(Enum):
class ApprovalEngine:
"""Deny-by-default tool approval."""
"""Deny-by-default tool approval with session allowlist and audit log."""
def __init__(
self,
@@ -26,6 +27,8 @@ class ApprovalEngine:
self._auto_approve = set(auto_approve or [])
self._require_approval = set(require_approval or [])
self._interactive = interactive
self._session_allowed: set[str] = set()
self._audit_log: list[dict[str, Any]] = []
def get_policy(self, tool_name: str) -> ApprovalPolicy:
"""Get the approval policy for a tool."""
@@ -43,27 +46,54 @@ class ApprovalEngine:
policy = self.get_policy(tool_name)
if policy == ApprovalPolicy.AUTO_APPROVE:
self._log_decision(tool_name, arguments, "auto_approved")
return True
# Session-scoped "always allow"
if tool_name in self._session_allowed:
self._log_decision(tool_name, arguments, "session_allowed")
return True
if policy == ApprovalPolicy.DENY:
logger.warning(f"Tool '{tool_name}' denied by policy")
self._log_decision(tool_name, arguments, "denied")
return False
# REQUIRE_APPROVAL
if not self._interactive:
logger.warning(f"Tool '{tool_name}' requires approval but running non-interactively — denied")
self._log_decision(tool_name, arguments, "denied_non_interactive")
return False
# In interactive mode, prompt the user
logger.info(f"Tool '{tool_name}' requires approval. Args: {arguments}")
return await self._prompt_user(tool_name, arguments)
approved, always = await self._prompt_user(tool_name, arguments)
if approved and always:
self._session_allowed.add(tool_name)
self._log_decision(tool_name, arguments, "user_approved" if approved else "user_denied")
return approved
async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> tuple[bool, bool]:
    """Prompt the user on stdin for tool approval.

    Returns:
        (approved, always_allow) — ``always_allow`` is True when the user
        answered "a"/"always", requesting a session-wide allow for this tool.

    BUGFIX: removed leftover lines from the pre-tuple version of this method
    (the old ``[y/N]`` prompt and bare ``return`` statements) that the merge
    left interleaved with the new code.
    """
    print(f"\n[APPROVAL REQUIRED] Tool: {tool_name}")
    print(f" Arguments: {arguments}")
    try:
        answer = input(" Allow? [y/N/a(lways)]: ").strip().lower()
        if answer in ("a", "always"):
            return True, True
        return answer in ("y", "yes"), False
    except (EOFError, KeyboardInterrupt):
        # Treat EOF / Ctrl-C as a denial rather than crashing the loop.
        return False, False
def _log_decision(self, tool_name: str, arguments: dict[str, Any], decision: str) -> None:
    """Append a single approval decision to the in-memory audit log."""
    entry = {
        "tool": tool_name,
        "arguments": arguments,
        "decision": decision,
        "timestamp": time.time(),
    }
    self._audit_log.append(entry)
def get_audit_log(self) -> list[dict[str, Any]]:
    """Return a shallow copy of the audit log for inspection."""
    return self._audit_log.copy()