From 872ed24f0c28dd74595d9794215a2803b2e02465 Mon Sep 17 00:00:00 2001 From: Kaloyan Danchev Date: Thu, 19 Feb 2026 09:20:52 +0200 Subject: [PATCH] Add performance features: caching, cost tracking, retry, compaction, classification, scrubbing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Inspired by zeroclaw's lightweight patterns for slow hardware: - Response cache (SQLite + SHA-256 keyed) to skip redundant LLM calls - History compaction — LLM-summarize old messages when history exceeds 50 - Query classifier routes simple/research queries to cheaper models - Credential scrubbing removes secrets from tool output before sending to LLM - Cost tracker with daily/monthly budget enforcement (SQLite) - Resilient provider with retry + exponential backoff + fallback provider - Approval engine gains session "always allow" and audit log Co-Authored-By: Claude Opus 4.6 --- config.yaml | 9 ++ pyproject.toml | 1 + xtrm_agent/cache.py | 100 +++++++++++++++++ xtrm_agent/classifier.py | 86 +++++++++++++++ xtrm_agent/config.py | 11 ++ xtrm_agent/cost.py | 173 +++++++++++++++++++++++++++++ xtrm_agent/engine.py | 203 +++++++++++++++++++++++++++++++++-- xtrm_agent/orchestrator.py | 38 +++++++ xtrm_agent/scrub.py | 46 ++++++++ xtrm_agent/tools/approval.py | 44 ++++++-- 10 files changed, 694 insertions(+), 17 deletions(-) create mode 100644 xtrm_agent/cache.py create mode 100644 xtrm_agent/classifier.py create mode 100644 xtrm_agent/cost.py create mode 100644 xtrm_agent/scrub.py diff --git a/config.yaml b/config.yaml index af8f198..16110a2 100644 --- a/config.yaml +++ b/config.yaml @@ -43,3 +43,12 @@ agents: orchestrator: max_concurrent: 5 delegation_timeout: 120 + +performance: + cache_ttl: 3600 + daily_budget_usd: 5.0 + monthly_budget_usd: 50.0 + fallback_model: nvidia_nim/deepseek-ai/deepseek-v3.1 + model_routing: + simple: nvidia_nim/deepseek-ai/deepseek-v3.1 + research: nvidia_nim/moonshotai/kimi-k2.5 diff --git a/pyproject.toml b/pyproject.toml index 60af028..2984283 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "json-repair>=0.30.0", "duckduckgo-search>=7.0.0", "pypdf>=5.0.0", + "aiosqlite>=0.20.0", ] [project.scripts] diff --git a/xtrm_agent/cache.py b/xtrm_agent/cache.py new file mode 100644 index 0000000..c691200 --- /dev/null +++ b/xtrm_agent/cache.py @@ -0,0 +1,100 @@ +"""Response cache — avoid redundant LLM calls for identical prompts.""" + +from __future__ import annotations + +import hashlib +import json +import time +from pathlib import Path + +import aiosqlite +from loguru import logger + + +class ResponseCache: + """SQLite-backed LLM response cache with TTL expiry.""" + + def __init__(self, db_path: str | Path = "data/cache.db", ttl: int = 3600) -> None: + self._db_path = str(db_path) + self._ttl = ttl + self._db: aiosqlite.Connection | None = None + + async def setup(self) -> None: + """Create the cache table if it doesn't exist.""" + Path(self._db_path).parent.mkdir(parents=True, exist_ok=True) + self._db = await aiosqlite.connect(self._db_path) + await self._db.execute("PRAGMA journal_mode=WAL") + await self._db.execute( + """ + CREATE TABLE IF NOT EXISTS response_cache ( + key TEXT PRIMARY KEY, + response TEXT NOT NULL, + model TEXT NOT NULL, + created_at REAL NOT NULL, + hits INTEGER DEFAULT 0 + ) + """ + ) + await self._db.commit() + + async def close(self) -> None: + if self._db: + await self._db.close() + self._db = None + + @staticmethod + def _make_key(model: str, messages: list[dict]) -> str: + """SHA-256 hash of model + messages for cache key.""" + raw = json.dumps({"model": model, "messages": messages}, sort_keys=True) + return hashlib.sha256(raw.encode()).hexdigest() + + async def get(self, model: str, messages: list[dict]) -> str | None: + """Look up a cached response. Returns None on miss or expired.""" + if not self._db: + return None + key = self._make_key(model, messages) + now = time.time() + async with self._db.execute( + "SELECT response, created_at FROM response_cache WHERE key = ?", + (key,), + ) as cursor: + row = await cursor.fetchone() + if not row: + return None + response, created_at = row + if now - created_at > self._ttl: + await self._db.execute("DELETE FROM response_cache WHERE key = ?", (key,)) + await self._db.commit() + return None + # Bump hit count + await self._db.execute( + "UPDATE response_cache SET hits = hits + 1 WHERE key = ?", (key,) + ) + await self._db.commit() + logger.debug(f"Cache hit for {model} (key={key[:12]}...)") + return response + + async def put(self, model: str, messages: list[dict], response: str) -> None: + """Store a response in the cache.""" + if not self._db: + return + key = self._make_key(model, messages) + await self._db.execute( + """ + INSERT OR REPLACE INTO response_cache (key, response, model, created_at, hits) + VALUES (?, ?, ?, ?, 0) + """, + (key, response, model, time.time()), + ) + await self._db.commit() + + async def clear_expired(self) -> int: + """Remove expired entries. Returns count of deleted rows.""" + if not self._db: + return 0 + cutoff = time.time() - self._ttl + cursor = await self._db.execute( + "DELETE FROM response_cache WHERE created_at < ?", (cutoff,) + ) + await self._db.commit() + return cursor.rowcount diff --git a/xtrm_agent/classifier.py b/xtrm_agent/classifier.py new file mode 100644 index 0000000..0ee826d --- /dev/null +++ b/xtrm_agent/classifier.py @@ -0,0 +1,86 @@ +"""Query classifier — route queries to appropriate models.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ClassificationResult: + """Result of classifying a query.""" + + category: str # "simple", "code", "research", "complex" + model_hint: str # suggested model override, or "" to use default + + +# Short, trivial queries — greetings, thanks, yes/no, etc. +_SIMPLE_PATTERNS = [ + re.compile(r"^(hi|hello|hey|thanks|thank you|ok|yes|no|sure|bye|good\s*(morning|evening|night))[\s!.?]*$", re.IGNORECASE), + re.compile(r"^(what time|what date|what day)[\s\S]{0,20}\??\s*$", re.IGNORECASE), +] + +# Code-related queries +_CODE_PATTERNS = [ + re.compile(r"\b(function|class|def|import|const|let|var|return|async|await)\b"), + re.compile(r"\b(refactor|debug|fix|implement|code|compile|build|test|lint|deploy)\b", re.IGNORECASE), + re.compile(r"```"), + re.compile(r"\.(py|js|ts|rs|go|java|cpp|c|rb|php|swift|kt)\b"), +] + +# Research / information queries +_RESEARCH_PATTERNS = [ + re.compile(r"\b(search|find|look\s*up|research|summarize|explain|compare|analyze)\b", re.IGNORECASE), + re.compile(r"\b(what is|how does|why does|can you explain)\b", re.IGNORECASE), +] + + +class QueryClassifier: + """Classify queries for model routing. + + Configure with a mapping of category → model override. + Categories: "simple", "code", "research", "complex". + """ + + def __init__(self, model_map: dict[str, str] | None = None) -> None: + self._model_map = model_map or {} + + def classify(self, query: str) -> ClassificationResult: + """Classify a query and suggest a model.""" + query_stripped = query.strip() + + # Very short queries are likely simple + if len(query_stripped) < 20: + for pattern in _SIMPLE_PATTERNS: + if pattern.match(query_stripped): + return ClassificationResult( + category="simple", + model_hint=self._model_map.get("simple", ""), + ) + + # Code patterns + code_score = sum(1 for p in _CODE_PATTERNS if p.search(query_stripped)) + if code_score >= 2: + return ClassificationResult( + category="code", + model_hint=self._model_map.get("code", ""), + ) + + # Research patterns + research_score = sum(1 for p in _RESEARCH_PATTERNS if p.search(query_stripped)) + if research_score >= 1: + return ClassificationResult( + category="research", + model_hint=self._model_map.get("research", ""), + ) + + # Long or complex queries + if len(query_stripped) > 500 or query_stripped.count("\n") > 5: + return ClassificationResult( + category="complex", + model_hint=self._model_map.get("complex", ""), + ) + + # Default — no override + return ClassificationResult(category="complex", model_hint="") diff --git a/xtrm_agent/config.py b/xtrm_agent/config.py index 2ab4335..852c29e 100644 --- a/xtrm_agent/config.py +++ b/xtrm_agent/config.py @@ -61,6 +61,16 @@ class MCPServerConfig(BaseModel): url: str = "" +class PerformanceConfig(BaseModel): + """Performance tuning — caching, cost tracking, model routing.""" + + cache_ttl: int = 3600 + daily_budget_usd: float = 0.0 + monthly_budget_usd: float = 0.0 + fallback_model: str = "" + model_routing: dict[str, str] = Field(default_factory=dict) + + class OrchestratorConfig(BaseModel): max_concurrent: int = 5 delegation_timeout: int = 120 @@ -87,6 +97,7 @@ class Config(BaseModel): mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict) agents: dict[str, str] = Field(default_factory=dict) orchestrator: OrchestratorConfig = Field(default_factory=OrchestratorConfig) + performance: PerformanceConfig = Field(default_factory=PerformanceConfig) def load_config(path: str | Path = "config.yaml") -> Config: diff --git a/xtrm_agent/cost.py b/xtrm_agent/cost.py new file mode 100644 index 0000000..128d6fc --- /dev/null +++ b/xtrm_agent/cost.py @@ -0,0 +1,173 @@ +"""Cost tracking — monitor token usage and enforce budget limits.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from pathlib import Path + +import aiosqlite +from loguru import logger + + +# Pricing per 1M tokens (input, output) in USD +MODEL_PRICING: dict[str, tuple[float, float]] = { + # Anthropic + "claude-sonnet-4-5-20250929": (3.0, 15.0), + "claude-haiku-4-5-20251001": (0.80, 4.0), + "claude-opus-4-6": (15.0, 75.0), + # NVIDIA NIM / LiteLLM + "nvidia_nim/deepseek-ai/deepseek-v3.1": (0.27, 1.10), + "nvidia_nim/moonshotai/kimi-k2.5": (0.35, 1.40), + "nvidia_nim/minimaxai/minimax-m2.1": (0.30, 1.10), + # Fallback + "_default": (1.0, 3.0), +} + + +def _get_pricing(model: str) -> tuple[float, float]: + """Get (input_price, output_price) per 1M tokens for a model.""" + return MODEL_PRICING.get(model, MODEL_PRICING["_default"]) + + +@dataclass +class CostRecord: + model: str + input_tokens: int + output_tokens: int + cost_usd: float + timestamp: float + + +@dataclass +class BudgetConfig: + daily_limit_usd: float = 0.0 # 0 = no limit + monthly_limit_usd: float = 0.0 # 0 = no limit + warning_threshold: float = 0.8 # warn at 80% of budget + + +class CostTracker: + """Track LLM costs in SQLite and enforce budget limits.""" + + def __init__( + self, + db_path: str | Path = "data/costs.db", + budget: BudgetConfig | None = None, + ) -> None: + self._db_path = str(db_path) + self._budget = budget or BudgetConfig() + self._db: aiosqlite.Connection | None = None + + async def setup(self) -> None: + Path(self._db_path).parent.mkdir(parents=True, exist_ok=True) + self._db = await aiosqlite.connect(self._db_path) + await self._db.execute("PRAGMA journal_mode=WAL") + await self._db.execute( + """ + CREATE TABLE IF NOT EXISTS cost_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model TEXT NOT NULL, + input_tokens INTEGER NOT NULL, + output_tokens INTEGER NOT NULL, + cost_usd REAL NOT NULL, + timestamp REAL NOT NULL + ) + """ + ) + await self._db.commit() + + async def close(self) -> None: + if self._db: + await self._db.close() + self._db = None + + async def record(self, model: str, usage: dict[str, int]) -> CostRecord: + """Record token usage and return the cost record.""" + input_tokens = usage.get("input_tokens", 0) + output_tokens = usage.get("output_tokens", 0) + input_price, output_price = _get_pricing(model) + cost = (input_tokens * input_price + output_tokens * output_price) / 1_000_000 + + record = CostRecord( + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=cost, + timestamp=time.time(), + ) + + if self._db: + await self._db.execute( + "INSERT INTO cost_records (model, input_tokens, output_tokens, cost_usd, timestamp) VALUES (?, ?, ?, ?, ?)", + (model, input_tokens, output_tokens, cost, record.timestamp), + ) + await self._db.commit() + + logger.debug( + f"Cost: {model} — {input_tokens}in/{output_tokens}out = ${cost:.6f}" + ) + return record + + async def check_budget(self) -> str: + """Check budget status. Returns 'ok', 'warning', or 'exceeded'.""" + if not self._db: + return "ok" + + now = time.time() + + # Daily check + if self._budget.daily_limit_usd > 0: + day_start = now - (now % 86400) + daily_total = await self._sum_costs_since(day_start) + if daily_total >= self._budget.daily_limit_usd: + logger.warning(f"Daily budget EXCEEDED: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}") + return "exceeded" + if daily_total >= self._budget.daily_limit_usd * self._budget.warning_threshold: + logger.warning(f"Daily budget warning: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}") + return "warning" + + # Monthly check + if self._budget.monthly_limit_usd > 0: + month_start = now - (now % (86400 * 30)) + monthly_total = await self._sum_costs_since(month_start) + if monthly_total >= self._budget.monthly_limit_usd: + logger.warning(f"Monthly budget EXCEEDED: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}") + return "exceeded" + if monthly_total >= self._budget.monthly_limit_usd * self._budget.warning_threshold: + logger.warning(f"Monthly budget warning: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}") + return "warning" + + return "ok" + + async def get_summary(self) -> dict: + """Get a cost summary with per-model breakdown.""" + if not self._db: + return {"total_usd": 0, "models": {}} + + now = time.time() + day_start = now - (now % 86400) + + rows: list = [] + async with self._db.execute( + "SELECT model, SUM(input_tokens), SUM(output_tokens), SUM(cost_usd) FROM cost_records WHERE timestamp >= ? GROUP BY model", + (day_start,), + ) as cursor: + rows = await cursor.fetchall() + + models = {} + total = 0.0 + for model, inp, out, cost in rows: + models[model] = {"input_tokens": inp, "output_tokens": out, "cost_usd": cost} + total += cost + + return {"total_usd": total, "models": models} + + async def _sum_costs_since(self, since: float) -> float: + if not self._db: + return 0.0 + async with self._db.execute( + "SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE timestamp >= ?", + (since,), + ) as cursor: + row = await cursor.fetchone() + return row[0] if row else 0.0 diff --git a/xtrm_agent/engine.py b/xtrm_agent/engine.py index e3eed59..c3938a0 100644 --- a/xtrm_agent/engine.py +++ b/xtrm_agent/engine.py @@ -2,17 +2,45 @@ from __future__ import annotations +import asyncio +import json from typing import Any from loguru import logger +from xtrm_agent.cache import ResponseCache +from xtrm_agent.classifier import ClassificationResult, QueryClassifier from xtrm_agent.config import AgentFileConfig -import json - +from xtrm_agent.cost import CostTracker from xtrm_agent.llm.provider import LLMProvider, LLMResponse +from xtrm_agent.scrub import scrub_credentials from xtrm_agent.tools.approval import ApprovalEngine from xtrm_agent.tools.registry import ToolRegistry +# History compaction settings +_MAX_HISTORY_MESSAGES = 50 +_KEEP_RECENT = 20 +_COMPACTION_PROMPT = ( + "Summarize the following conversation history concisely. " + "Preserve key facts, decisions, tool results, and context needed to continue. " + "Be brief but complete." +) + +# Retry settings +_MAX_RETRIES = 3 +_RETRY_BACKOFF_BASE = 2.0 +_NON_RETRYABLE_KEYWORDS = [ + "authentication", + "unauthorized", + "invalid api key", + "model not found", + "permission denied", + "invalid_api_key", + "401", + "403", + "404", +] + class Engine: """Runs one agent's LLM loop: messages → LLM → tool calls → loop → response.""" @@ -23,16 +51,33 @@ class Engine: provider: LLMProvider, tools: ToolRegistry, approval: ApprovalEngine, + cache: ResponseCache | None = None, + cost_tracker: CostTracker | None = None, + classifier: QueryClassifier | None = None, + fallback_provider: LLMProvider | None = None, ) -> None: self.config = agent_config self.provider = provider self.tools = tools self.approval = approval + self.cache = cache + self.cost_tracker = cost_tracker + self.classifier = classifier + self.fallback_provider = fallback_provider async def run(self, user_message: str) -> str: """Process a single user message through the agent loop.""" messages = self._build_initial_messages(user_message) - return await self._agent_loop(messages) + + # Query classification — override model if classifier suggests one + model_override = "" + if self.classifier: + result = self.classifier.classify(user_message) + if result.model_hint: + model_override = result.model_hint + logger.debug(f"[{self.config.name}] Classified as '{result.category}' → model={result.model_hint}") + + return await self._agent_loop(messages, model_override=model_override) async def run_delegation(self, task: str) -> str: """Process a delegation task (no system prompt changes).""" @@ -46,22 +91,45 @@ class Engine: messages.append({"role": "user", "content": user_message}) return messages - async def _agent_loop(self, messages: list[dict[str, Any]]) -> str: - """Core agent iteration loop.""" + async def _agent_loop( + self, + messages: list[dict[str, Any]], + model_override: str = "", + ) -> str: + """Core agent iteration loop with caching, cost tracking, retry, and compaction.""" for iteration in range(self.config.max_iterations): - model = self.config.model or self.provider.get_default_model() + model = model_override or self.config.model or self.provider.get_default_model() tool_defs = self.tools.get_definitions() if self.tools.names() else None - response = await self.provider.complete( + # Budget check + if self.cost_tracker: + status = await self.cost_tracker.check_budget() + if status == "exceeded": + return "(budget exceeded — please try again later)" + + # Response cache check (only for first iteration — no tool calls yet) + if self.cache and iteration == 0 and not tool_defs: + cached = await self.cache.get(model, messages) + if cached: + return cached + + # LLM call with retry + fallback + response = await self._call_with_retry( messages=messages, tools=tool_defs, model=model, - max_tokens=8192, - temperature=self.config.temperature, ) + # Track cost + if self.cost_tracker and response.usage: + await self.cost_tracker.record(model, response.usage) + if not response.has_tool_calls: - return response.content or "(no response)" + content = response.content or "(no response)" + # Cache the final response (only if no tool calls were used in the conversation) + if self.cache and iteration == 0: + await self.cache.put(model, messages, content) + return content # Add assistant message with tool calls messages.append(self._assistant_message(response)) @@ -69,6 +137,8 @@ class Engine: # Execute each tool call for tc in response.tool_calls: result = await self._execute_tool(tc.name, tc.arguments) + # Scrub credentials from tool output + result = scrub_credentials(result) messages.append( { "role": "tool", @@ -83,8 +153,121 @@ class Engine: f"{len(response.tool_calls)} tool call(s)" ) + # History compaction — summarize old messages if history is too long + messages = await self._compact_history(messages, model) + return "(max iterations reached)" + async def _call_with_retry( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str, + ) -> LLMResponse: + """Call the LLM with retry and fallback.""" + last_error: Exception | None = None + + for attempt in range(_MAX_RETRIES): + try: + return await self.provider.complete( + messages=messages, + tools=tools, + model=model, + max_tokens=8192, + temperature=self.config.temperature, + ) + except Exception as e: + last_error = e + error_str = str(e).lower() + + # Don't retry non-retryable errors + if any(kw in error_str for kw in _NON_RETRYABLE_KEYWORDS): + logger.error(f"[{self.config.name}] Non-retryable error: {e}") + break + + wait = _RETRY_BACKOFF_BASE ** attempt + logger.warning( + f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — retrying in {wait}s" + ) + await asyncio.sleep(wait) + + # Try fallback provider + if self.fallback_provider: + logger.info(f"[{self.config.name}] Trying fallback provider") + try: + return await self.fallback_provider.complete( + messages=messages, + tools=tools, + model=self.fallback_provider.get_default_model(), + max_tokens=8192, + temperature=self.config.temperature, + ) + except Exception as fallback_err: + logger.error(f"[{self.config.name}] Fallback also failed: {fallback_err}") + + # All retries and fallback exhausted + return LLMResponse( + content=f"(error after {_MAX_RETRIES} retries: {last_error})", + finish_reason="error", + ) + + async def _compact_history( + self, + messages: list[dict[str, Any]], + model: str, + ) -> list[dict[str, Any]]: + """Compact old messages if history exceeds threshold.""" + # Count non-system messages + non_system = [m for m in messages if m["role"] != "system"] + if len(non_system) <= _MAX_HISTORY_MESSAGES: + return messages + + # Split: keep system messages, summarize old, keep recent + system_msgs = [m for m in messages if m["role"] == "system"] + old_msgs = non_system[:-_KEEP_RECENT] + recent_msgs = non_system[-_KEEP_RECENT:] + + # Ask LLM to summarize the old messages + summary_messages = [ + {"role": "system", "content": _COMPACTION_PROMPT}, + { + "role": "user", + "content": json.dumps( + [{"role": m["role"], "content": m.get("content", "")} for m in old_msgs], + indent=None, + ), + }, + ] + + try: + summary_response = await self.provider.complete( + messages=summary_messages, + model=model, + max_tokens=2048, + temperature=0.1, + ) + summary_text = summary_response.content or "" + logger.info( + f"[{self.config.name}] Compacted {len(old_msgs)} messages → " + f"{len(summary_text)} char summary" + ) + except Exception as e: + logger.warning(f"[{self.config.name}] History compaction failed: {e}") + return messages + + # Rebuild: system + summary + recent + compacted = list(system_msgs) + compacted.append({ + "role": "user", + "content": f"[Conversation summary]\n{summary_text}", + }) + compacted.append({ + "role": "assistant", + "content": "Understood, I have the context from our previous conversation.", + }) + compacted.extend(recent_msgs) + return compacted + async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str: """Execute a tool with approval check.""" approved = await self.approval.check(name, arguments) diff --git a/xtrm_agent/orchestrator.py b/xtrm_agent/orchestrator.py index ff4a6aa..7af4336 100644 --- a/xtrm_agent/orchestrator.py +++ b/xtrm_agent/orchestrator.py @@ -10,7 +10,10 @@ from typing import Any from loguru import logger from xtrm_agent.bus import AgentMessage, InboundMessage, MessageBus, OutboundMessage +from xtrm_agent.cache import ResponseCache +from xtrm_agent.classifier import QueryClassifier from xtrm_agent.config import Config, AgentFileConfig, parse_agent_file +from xtrm_agent.cost import BudgetConfig, CostTracker from xtrm_agent.engine import Engine from xtrm_agent.llm.anthropic import AnthropicProvider from xtrm_agent.llm.litellm import LiteLLMProvider @@ -35,6 +38,8 @@ class Orchestrator: self._agent_configs: dict[str, AgentFileConfig] = {} self._mcp_stack = AsyncExitStack() self._running = False + self._cache: ResponseCache | None = None + self._cost_tracker: CostTracker | None = None # Channel defaults for routing channel_defaults = {} @@ -53,6 +58,27 @@ class Orchestrator: workspace = Path(self.config.tools.workspace).resolve() workspace.mkdir(parents=True, exist_ok=True) + # Initialize shared response cache + self._cache = ResponseCache( + db_path=workspace / "cache.db", + ttl=self.config.performance.cache_ttl, + ) + await self._cache.setup() + + # Initialize cost tracker + budget = BudgetConfig( + daily_limit_usd=self.config.performance.daily_budget_usd, + monthly_limit_usd=self.config.performance.monthly_budget_usd, + ) + self._cost_tracker = CostTracker( + db_path=workspace / "costs.db", + budget=budget, + ) + await self._cost_tracker.setup() + + # Initialize query classifier + classifier = QueryClassifier(model_map=self.config.performance.model_routing) + # Parse all agent definitions for agent_name, agent_path in self.config.agents.items(): p = Path(agent_path) @@ -74,6 +100,10 @@ class Orchestrator: await self._mcp_stack.__aenter__() await connect_mcp_servers(self.config.mcp_servers, global_registry, self._mcp_stack) + # Create fallback provider (LiteLLM with a cheap model) + fallback_model = self.config.performance.fallback_model + fallback_provider = LiteLLMProvider(model=fallback_model) if fallback_model else None + # Create one engine per agent agent_names = list(self._agent_configs.keys()) for agent_name, agent_cfg in self._agent_configs.items(): @@ -107,6 +137,10 @@ class Orchestrator: provider=provider, tools=agent_registry, approval=approval, + cache=self._cache, + cost_tracker=self._cost_tracker, + classifier=classifier, + fallback_provider=fallback_provider, ) self._engines[agent_name] = engine @@ -190,6 +224,10 @@ class Orchestrator: async def stop(self) -> None: self._running = False + if self._cache: + await self._cache.close() + if self._cost_tracker: + await self._cost_tracker.close() await self._mcp_stack.aclose() logger.info("Orchestrator stopped") diff --git a/xtrm_agent/scrub.py b/xtrm_agent/scrub.py new file mode 100644 index 0000000..cb9e4fb --- /dev/null +++ b/xtrm_agent/scrub.py @@ -0,0 +1,46 @@ +"""Credential scrubbing — prevent secret leakage in tool output.""" + +from __future__ import annotations + +import re + +# Patterns that match common secret formats +_SECRET_PATTERNS = [ + # Key=value patterns (API keys, tokens, passwords) + re.compile( + r"(?i)(api[_-]?key|token|password|passwd|secret|access[_-]?key|private[_-]?key|auth)" + r"[\s]*[=:]\s*['\"]?([^\s'\"]{8,})['\"]?", + ), + # Bearer tokens + re.compile(r"Bearer\s+[A-Za-z0-9\-._~+/]+=*", re.IGNORECASE), + # AWS access keys (AKIA...) + re.compile(r"AKIA[0-9A-Z]{16}"), + # AWS secret keys (40 char base64) + re.compile(r"(?i)aws[_-]?secret[_-]?access[_-]?key[\s]*[=:]\s*['\"]?([A-Za-z0-9/+=]{40})['\"]?"), + # GitHub tokens + re.compile(r"gh[pousr]_[A-Za-z0-9_]{36,}"), + # Generic long hex strings that look like secrets (32+ hex chars after key= or token=) + re.compile(r"(?i)(?:key|token|secret)[=:]\s*['\"]?([0-9a-f]{32,})['\"]?"), +] + +_REDACTED = "[REDACTED]" + + +def scrub_credentials(text: str) -> str: + """Scrub potential secrets from text, replacing with [REDACTED].""" + result = text + for pattern in _SECRET_PATTERNS: + result = pattern.sub(_redact_match, result) + return result + + +def _redact_match(match: re.Match) -> str: + """Replace the secret value while keeping the key name visible.""" + full = match.group(0) + # For key=value patterns, keep the key part + for sep in ("=", ":"): + if sep in full: + key_part = full[: full.index(sep) + 1] + return f"{key_part} {_REDACTED}" + # For standalone patterns (Bearer, AKIA, gh*_), redact the whole thing + return _REDACTED diff --git a/xtrm_agent/tools/approval.py b/xtrm_agent/tools/approval.py index 25220b9..83cfad4 100644 --- a/xtrm_agent/tools/approval.py +++ b/xtrm_agent/tools/approval.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from enum import Enum from typing import Any @@ -15,7 +16,7 @@ class ApprovalPolicy(Enum): class ApprovalEngine: - """Deny-by-default tool approval.""" + """Deny-by-default tool approval with session allowlist and audit log.""" def __init__( self, @@ -26,6 +27,8 @@ class ApprovalEngine: self._auto_approve = set(auto_approve or []) self._require_approval = set(require_approval or []) self._interactive = interactive + self._session_allowed: set[str] = set() + self._audit_log: list[dict[str, Any]] = [] def get_policy(self, tool_name: str) -> ApprovalPolicy: """Get the approval policy for a tool.""" @@ -43,27 +46,54 @@ class ApprovalEngine: policy = self.get_policy(tool_name) if policy == ApprovalPolicy.AUTO_APPROVE: + self._log_decision(tool_name, arguments, "auto_approved") + return True + + # Session-scoped "always allow" + if tool_name in self._session_allowed: + self._log_decision(tool_name, arguments, "session_allowed") return True if policy == ApprovalPolicy.DENY: logger.warning(f"Tool '{tool_name}' denied by policy") + self._log_decision(tool_name, arguments, "denied") return False # REQUIRE_APPROVAL if not self._interactive: logger.warning(f"Tool '{tool_name}' requires approval but running non-interactively — denied") + self._log_decision(tool_name, arguments, "denied_non_interactive") return False # In interactive mode, prompt the user logger.info(f"Tool '{tool_name}' requires approval. Args: {arguments}") - return await self._prompt_user(tool_name, arguments) + approved, always = await self._prompt_user(tool_name, arguments) + if approved and always: + self._session_allowed.add(tool_name) + self._log_decision(tool_name, arguments, "user_approved" if approved else "user_denied") + return approved - async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> bool: - """Prompt user for tool approval (interactive mode).""" + async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> tuple[bool, bool]: + """Prompt user for tool approval. Returns (approved, always_allow).""" print(f"\n[APPROVAL REQUIRED] Tool: {tool_name}") print(f" Arguments: {arguments}") try: - answer = input(" Allow? [y/N]: ").strip().lower() - return answer in ("y", "yes") + answer = input(" Allow? [y/N/a(lways)]: ").strip().lower() + if answer in ("a", "always"): + return True, True + return answer in ("y", "yes"), False except (EOFError, KeyboardInterrupt): - return False + return False, False + + def _log_decision(self, tool_name: str, arguments: dict[str, Any], decision: str) -> None: + """Record an approval decision in the audit log.""" + self._audit_log.append({ + "tool": tool_name, + "arguments": arguments, + "decision": decision, + "timestamp": time.time(), + }) + + def get_audit_log(self) -> list[dict[str, Any]]: + """Return the audit log for inspection.""" + return list(self._audit_log)