Add performance features: caching, cost tracking, retry, compaction, classification, scrubbing
Inspired by zeroclaw's lightweight patterns for slow hardware:

- Response cache (SQLite + SHA-256 keyed) to skip redundant LLM calls
- History compaction — LLM-summarize old messages when history exceeds 50
- Query classifier routes simple/research queries to cheaper models
- Credential scrubbing removes secrets from tool output before sending to LLM
- Cost tracker with daily/monthly budget enforcement (SQLite)
- Resilient provider with retry + exponential backoff + fallback provider
- Approval engine gains session "always allow" and audit log

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -43,3 +43,12 @@ agents:
|
|||||||
orchestrator:
|
orchestrator:
|
||||||
max_concurrent: 5
|
max_concurrent: 5
|
||||||
delegation_timeout: 120
|
delegation_timeout: 120
|
||||||
|
|
||||||
|
performance:
|
||||||
|
cache_ttl: 3600
|
||||||
|
daily_budget_usd: 5.0
|
||||||
|
monthly_budget_usd: 50.0
|
||||||
|
fallback_model: nvidia_nim/deepseek-ai/deepseek-v3.1
|
||||||
|
model_routing:
|
||||||
|
simple: nvidia_nim/deepseek-ai/deepseek-v3.1
|
||||||
|
research: nvidia_nim/moonshotai/kimi-k2.5
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ dependencies = [
|
|||||||
"json-repair>=0.30.0",
|
"json-repair>=0.30.0",
|
||||||
"duckduckgo-search>=7.0.0",
|
"duckduckgo-search>=7.0.0",
|
||||||
"pypdf>=5.0.0",
|
"pypdf>=5.0.0",
|
||||||
|
"aiosqlite>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
100
xtrm_agent/cache.py
Normal file
100
xtrm_agent/cache.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""Response cache — avoid redundant LLM calls for identical prompts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseCache:
|
||||||
|
"""SQLite-backed LLM response cache with TTL expiry."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path = "data/cache.db", ttl: int = 3600) -> None:
|
||||||
|
self._db_path = str(db_path)
|
||||||
|
self._ttl = ttl
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
async def setup(self) -> None:
|
||||||
|
"""Create the cache table if it doesn't exist."""
|
||||||
|
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS response_cache (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
response TEXT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
created_at REAL NOT NULL,
|
||||||
|
hits INTEGER DEFAULT 0
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._db:
|
||||||
|
await self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_key(model: str, messages: list[dict]) -> str:
|
||||||
|
"""SHA-256 hash of model + messages for cache key."""
|
||||||
|
raw = json.dumps({"model": model, "messages": messages}, sort_keys=True)
|
||||||
|
return hashlib.sha256(raw.encode()).hexdigest()
|
||||||
|
|
||||||
|
async def get(self, model: str, messages: list[dict]) -> str | None:
|
||||||
|
"""Look up a cached response. Returns None on miss or expired."""
|
||||||
|
if not self._db:
|
||||||
|
return None
|
||||||
|
key = self._make_key(model, messages)
|
||||||
|
now = time.time()
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT response, created_at FROM response_cache WHERE key = ?",
|
||||||
|
(key,),
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
response, created_at = row
|
||||||
|
if now - created_at > self._ttl:
|
||||||
|
await self._db.execute("DELETE FROM response_cache WHERE key = ?", (key,))
|
||||||
|
await self._db.commit()
|
||||||
|
return None
|
||||||
|
# Bump hit count
|
||||||
|
await self._db.execute(
|
||||||
|
"UPDATE response_cache SET hits = hits + 1 WHERE key = ?", (key,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
logger.debug(f"Cache hit for {model} (key={key[:12]}...)")
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def put(self, model: str, messages: list[dict], response: str) -> None:
|
||||||
|
"""Store a response in the cache."""
|
||||||
|
if not self._db:
|
||||||
|
return
|
||||||
|
key = self._make_key(model, messages)
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO response_cache (key, response, model, created_at, hits)
|
||||||
|
VALUES (?, ?, ?, ?, 0)
|
||||||
|
""",
|
||||||
|
(key, response, model, time.time()),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def clear_expired(self) -> int:
|
||||||
|
"""Remove expired entries. Returns count of deleted rows."""
|
||||||
|
if not self._db:
|
||||||
|
return 0
|
||||||
|
cutoff = time.time() - self._ttl
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"DELETE FROM response_cache WHERE created_at < ?", (cutoff,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
return cursor.rowcount
|
||||||
86
xtrm_agent/classifier.py
Normal file
86
xtrm_agent/classifier.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Query classifier — route queries to appropriate models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassificationResult:
|
||||||
|
"""Result of classifying a query."""
|
||||||
|
|
||||||
|
category: str # "simple", "code", "research", "complex"
|
||||||
|
model_hint: str # suggested model override, or "" to use default
|
||||||
|
|
||||||
|
|
||||||
|
# Short, trivial queries — greetings, thanks, yes/no, etc.
|
||||||
|
_SIMPLE_PATTERNS = [
|
||||||
|
re.compile(r"^(hi|hello|hey|thanks|thank you|ok|yes|no|sure|bye|good\s*(morning|evening|night))[\s!.?]*$", re.IGNORECASE),
|
||||||
|
re.compile(r"^(what time|what date|what day)[\s\S]{0,20}\??\s*$", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Code-related queries
|
||||||
|
_CODE_PATTERNS = [
|
||||||
|
re.compile(r"\b(function|class|def|import|const|let|var|return|async|await)\b"),
|
||||||
|
re.compile(r"\b(refactor|debug|fix|implement|code|compile|build|test|lint|deploy)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"```"),
|
||||||
|
re.compile(r"\.(py|js|ts|rs|go|java|cpp|c|rb|php|swift|kt)\b"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Research / information queries
|
||||||
|
_RESEARCH_PATTERNS = [
|
||||||
|
re.compile(r"\b(search|find|look\s*up|research|summarize|explain|compare|analyze)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\b(what is|how does|why does|can you explain)\b", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class QueryClassifier:
|
||||||
|
"""Classify queries for model routing.
|
||||||
|
|
||||||
|
Configure with a mapping of category → model override.
|
||||||
|
Categories: "simple", "code", "research", "complex".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_map: dict[str, str] | None = None) -> None:
|
||||||
|
self._model_map = model_map or {}
|
||||||
|
|
||||||
|
def classify(self, query: str) -> ClassificationResult:
|
||||||
|
"""Classify a query and suggest a model."""
|
||||||
|
query_stripped = query.strip()
|
||||||
|
|
||||||
|
# Very short queries are likely simple
|
||||||
|
if len(query_stripped) < 20:
|
||||||
|
for pattern in _SIMPLE_PATTERNS:
|
||||||
|
if pattern.match(query_stripped):
|
||||||
|
return ClassificationResult(
|
||||||
|
category="simple",
|
||||||
|
model_hint=self._model_map.get("simple", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Code patterns
|
||||||
|
code_score = sum(1 for p in _CODE_PATTERNS if p.search(query_stripped))
|
||||||
|
if code_score >= 2:
|
||||||
|
return ClassificationResult(
|
||||||
|
category="code",
|
||||||
|
model_hint=self._model_map.get("code", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Research patterns
|
||||||
|
research_score = sum(1 for p in _RESEARCH_PATTERNS if p.search(query_stripped))
|
||||||
|
if research_score >= 1:
|
||||||
|
return ClassificationResult(
|
||||||
|
category="research",
|
||||||
|
model_hint=self._model_map.get("research", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Long or complex queries
|
||||||
|
if len(query_stripped) > 500 or query_stripped.count("\n") > 5:
|
||||||
|
return ClassificationResult(
|
||||||
|
category="complex",
|
||||||
|
model_hint=self._model_map.get("complex", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default — no override
|
||||||
|
return ClassificationResult(category="complex", model_hint="")
|
||||||
@@ -61,6 +61,16 @@ class MCPServerConfig(BaseModel):
|
|||||||
url: str = ""
|
url: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PerformanceConfig(BaseModel):
|
||||||
|
"""Performance tuning — caching, cost tracking, model routing."""
|
||||||
|
|
||||||
|
# Lifetime of a cached LLM response in seconds; handed to ResponseCache as `ttl`.
cache_ttl: int = 3600
|
||||||
|
# Daily spend ceiling in USD, forwarded to BudgetConfig.daily_limit_usd (0 = no limit).
daily_budget_usd: float = 0.0
|
||||||
|
# Monthly spend ceiling in USD, forwarded to BudgetConfig.monthly_limit_usd (0 = no limit).
monthly_budget_usd: float = 0.0
|
||||||
|
# Model name for the LiteLLM fallback provider; an empty string disables the fallback.
fallback_model: str = ""
|
||||||
|
# Query category -> model override, passed to QueryClassifier as `model_map`.
model_routing: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class OrchestratorConfig(BaseModel):
|
class OrchestratorConfig(BaseModel):
|
||||||
max_concurrent: int = 5
|
max_concurrent: int = 5
|
||||||
delegation_timeout: int = 120
|
delegation_timeout: int = 120
|
||||||
@@ -87,6 +97,7 @@ class Config(BaseModel):
|
|||||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||||
agents: dict[str, str] = Field(default_factory=dict)
|
agents: dict[str, str] = Field(default_factory=dict)
|
||||||
orchestrator: OrchestratorConfig = Field(default_factory=OrchestratorConfig)
|
orchestrator: OrchestratorConfig = Field(default_factory=OrchestratorConfig)
|
||||||
|
performance: PerformanceConfig = Field(default_factory=PerformanceConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_config(path: str | Path = "config.yaml") -> Config:
|
def load_config(path: str | Path = "config.yaml") -> Config:
|
||||||
|
|||||||
173
xtrm_agent/cost.py
Normal file
173
xtrm_agent/cost.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""Cost tracking — monitor token usage and enforce budget limits."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
# Pricing per 1M tokens (input, output) in USD
|
||||||
|
MODEL_PRICING: dict[str, tuple[float, float]] = {
|
||||||
|
# Anthropic
|
||||||
|
"claude-sonnet-4-5-20250929": (3.0, 15.0),
|
||||||
|
"claude-haiku-4-5-20251001": (0.80, 4.0),
|
||||||
|
"claude-opus-4-6": (15.0, 75.0),
|
||||||
|
# NVIDIA NIM / LiteLLM
|
||||||
|
"nvidia_nim/deepseek-ai/deepseek-v3.1": (0.27, 1.10),
|
||||||
|
"nvidia_nim/moonshotai/kimi-k2.5": (0.35, 1.40),
|
||||||
|
"nvidia_nim/minimaxai/minimax-m2.1": (0.30, 1.10),
|
||||||
|
# Fallback
|
||||||
|
"_default": (1.0, 3.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pricing(model: str) -> tuple[float, float]:
|
||||||
|
"""Get (input_price, output_price) per 1M tokens for a model."""
|
||||||
|
return MODEL_PRICING.get(model, MODEL_PRICING["_default"])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CostRecord:
|
||||||
|
model: str
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
cost_usd: float
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BudgetConfig:
|
||||||
|
daily_limit_usd: float = 0.0 # 0 = no limit
|
||||||
|
monthly_limit_usd: float = 0.0 # 0 = no limit
|
||||||
|
warning_threshold: float = 0.8 # warn at 80% of budget
|
||||||
|
|
||||||
|
|
||||||
|
class CostTracker:
|
||||||
|
"""Track LLM costs in SQLite and enforce budget limits."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str | Path = "data/costs.db",
|
||||||
|
budget: BudgetConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._db_path = str(db_path)
|
||||||
|
self._budget = budget or BudgetConfig()
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
async def setup(self) -> None:
|
||||||
|
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS cost_records (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
input_tokens INTEGER NOT NULL,
|
||||||
|
output_tokens INTEGER NOT NULL,
|
||||||
|
cost_usd REAL NOT NULL,
|
||||||
|
timestamp REAL NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._db:
|
||||||
|
await self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
async def record(self, model: str, usage: dict[str, int]) -> CostRecord:
|
||||||
|
"""Record token usage and return the cost record."""
|
||||||
|
input_tokens = usage.get("input_tokens", 0)
|
||||||
|
output_tokens = usage.get("output_tokens", 0)
|
||||||
|
input_price, output_price = _get_pricing(model)
|
||||||
|
cost = (input_tokens * input_price + output_tokens * output_price) / 1_000_000
|
||||||
|
|
||||||
|
record = CostRecord(
|
||||||
|
model=model,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
cost_usd=cost,
|
||||||
|
timestamp=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._db:
|
||||||
|
await self._db.execute(
|
||||||
|
"INSERT INTO cost_records (model, input_tokens, output_tokens, cost_usd, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(model, input_tokens, output_tokens, cost, record.timestamp),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Cost: {model} — {input_tokens}in/{output_tokens}out = ${cost:.6f}"
|
||||||
|
)
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def check_budget(self) -> str:
|
||||||
|
"""Check budget status. Returns 'ok', 'warning', or 'exceeded'."""
|
||||||
|
if not self._db:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Daily check
|
||||||
|
if self._budget.daily_limit_usd > 0:
|
||||||
|
day_start = now - (now % 86400)
|
||||||
|
daily_total = await self._sum_costs_since(day_start)
|
||||||
|
if daily_total >= self._budget.daily_limit_usd:
|
||||||
|
logger.warning(f"Daily budget EXCEEDED: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}")
|
||||||
|
return "exceeded"
|
||||||
|
if daily_total >= self._budget.daily_limit_usd * self._budget.warning_threshold:
|
||||||
|
logger.warning(f"Daily budget warning: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}")
|
||||||
|
return "warning"
|
||||||
|
|
||||||
|
# Monthly check
|
||||||
|
if self._budget.monthly_limit_usd > 0:
|
||||||
|
month_start = now - (now % (86400 * 30))
|
||||||
|
monthly_total = await self._sum_costs_since(month_start)
|
||||||
|
if monthly_total >= self._budget.monthly_limit_usd:
|
||||||
|
logger.warning(f"Monthly budget EXCEEDED: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}")
|
||||||
|
return "exceeded"
|
||||||
|
if monthly_total >= self._budget.monthly_limit_usd * self._budget.warning_threshold:
|
||||||
|
logger.warning(f"Monthly budget warning: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}")
|
||||||
|
return "warning"
|
||||||
|
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
async def get_summary(self) -> dict:
|
||||||
|
"""Get a cost summary with per-model breakdown."""
|
||||||
|
if not self._db:
|
||||||
|
return {"total_usd": 0, "models": {}}
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
day_start = now - (now % 86400)
|
||||||
|
|
||||||
|
rows: list = []
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT model, SUM(input_tokens), SUM(output_tokens), SUM(cost_usd) FROM cost_records WHERE timestamp >= ? GROUP BY model",
|
||||||
|
(day_start,),
|
||||||
|
) as cursor:
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
models = {}
|
||||||
|
total = 0.0
|
||||||
|
for model, inp, out, cost in rows:
|
||||||
|
models[model] = {"input_tokens": inp, "output_tokens": out, "cost_usd": cost}
|
||||||
|
total += cost
|
||||||
|
|
||||||
|
return {"total_usd": total, "models": models}
|
||||||
|
|
||||||
|
async def _sum_costs_since(self, since: float) -> float:
|
||||||
|
if not self._db:
|
||||||
|
return 0.0
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE timestamp >= ?",
|
||||||
|
(since,),
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return row[0] if row else 0.0
|
||||||
@@ -2,17 +2,45 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from xtrm_agent.cache import ResponseCache
|
||||||
|
from xtrm_agent.classifier import ClassificationResult, QueryClassifier
|
||||||
from xtrm_agent.config import AgentFileConfig
|
from xtrm_agent.config import AgentFileConfig
|
||||||
import json
|
from xtrm_agent.cost import CostTracker
|
||||||
|
|
||||||
from xtrm_agent.llm.provider import LLMProvider, LLMResponse
|
from xtrm_agent.llm.provider import LLMProvider, LLMResponse
|
||||||
|
from xtrm_agent.scrub import scrub_credentials
|
||||||
from xtrm_agent.tools.approval import ApprovalEngine
|
from xtrm_agent.tools.approval import ApprovalEngine
|
||||||
from xtrm_agent.tools.registry import ToolRegistry
|
from xtrm_agent.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
# History compaction settings
|
||||||
|
_MAX_HISTORY_MESSAGES = 50
|
||||||
|
_KEEP_RECENT = 20
|
||||||
|
_COMPACTION_PROMPT = (
|
||||||
|
"Summarize the following conversation history concisely. "
|
||||||
|
"Preserve key facts, decisions, tool results, and context needed to continue. "
|
||||||
|
"Be brief but complete."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry settings
|
||||||
|
_MAX_RETRIES = 3
|
||||||
|
_RETRY_BACKOFF_BASE = 2.0
|
||||||
|
_NON_RETRYABLE_KEYWORDS = [
|
||||||
|
"authentication",
|
||||||
|
"unauthorized",
|
||||||
|
"invalid api key",
|
||||||
|
"model not found",
|
||||||
|
"permission denied",
|
||||||
|
"invalid_api_key",
|
||||||
|
"401",
|
||||||
|
"403",
|
||||||
|
"404",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
"""Runs one agent's LLM loop: messages → LLM → tool calls → loop → response."""
|
"""Runs one agent's LLM loop: messages → LLM → tool calls → loop → response."""
|
||||||
@@ -23,16 +51,33 @@ class Engine:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
tools: ToolRegistry,
|
tools: ToolRegistry,
|
||||||
approval: ApprovalEngine,
|
approval: ApprovalEngine,
|
||||||
|
cache: ResponseCache | None = None,
|
||||||
|
cost_tracker: CostTracker | None = None,
|
||||||
|
classifier: QueryClassifier | None = None,
|
||||||
|
fallback_provider: LLMProvider | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config = agent_config
|
self.config = agent_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.approval = approval
|
self.approval = approval
|
||||||
|
self.cache = cache
|
||||||
|
self.cost_tracker = cost_tracker
|
||||||
|
self.classifier = classifier
|
||||||
|
self.fallback_provider = fallback_provider
|
||||||
|
|
||||||
async def run(self, user_message: str) -> str:
|
async def run(self, user_message: str) -> str:
|
||||||
"""Process a single user message through the agent loop."""
|
"""Process a single user message through the agent loop."""
|
||||||
messages = self._build_initial_messages(user_message)
|
messages = self._build_initial_messages(user_message)
|
||||||
return await self._agent_loop(messages)
|
|
||||||
|
# Query classification — override model if classifier suggests one
|
||||||
|
model_override = ""
|
||||||
|
if self.classifier:
|
||||||
|
result = self.classifier.classify(user_message)
|
||||||
|
if result.model_hint:
|
||||||
|
model_override = result.model_hint
|
||||||
|
logger.debug(f"[{self.config.name}] Classified as '{result.category}' → model={result.model_hint}")
|
||||||
|
|
||||||
|
return await self._agent_loop(messages, model_override=model_override)
|
||||||
|
|
||||||
async def run_delegation(self, task: str) -> str:
|
async def run_delegation(self, task: str) -> str:
|
||||||
"""Process a delegation task (no system prompt changes)."""
|
"""Process a delegation task (no system prompt changes)."""
|
||||||
@@ -46,22 +91,45 @@ class Engine:
|
|||||||
messages.append({"role": "user", "content": user_message})
|
messages.append({"role": "user", "content": user_message})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def _agent_loop(self, messages: list[dict[str, Any]]) -> str:
|
async def _agent_loop(
|
||||||
"""Core agent iteration loop."""
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
model_override: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Core agent iteration loop with caching, cost tracking, retry, and compaction."""
|
||||||
for iteration in range(self.config.max_iterations):
|
for iteration in range(self.config.max_iterations):
|
||||||
model = self.config.model or self.provider.get_default_model()
|
model = model_override or self.config.model or self.provider.get_default_model()
|
||||||
tool_defs = self.tools.get_definitions() if self.tools.names() else None
|
tool_defs = self.tools.get_definitions() if self.tools.names() else None
|
||||||
|
|
||||||
response = await self.provider.complete(
|
# Budget check
|
||||||
|
if self.cost_tracker:
|
||||||
|
status = await self.cost_tracker.check_budget()
|
||||||
|
if status == "exceeded":
|
||||||
|
return "(budget exceeded — please try again later)"
|
||||||
|
|
||||||
|
# Response cache check (only for first iteration — no tool calls yet)
|
||||||
|
if self.cache and iteration == 0 and not tool_defs:
|
||||||
|
cached = await self.cache.get(model, messages)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# LLM call with retry + fallback
|
||||||
|
response = await self._call_with_retry(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tool_defs,
|
tools=tool_defs,
|
||||||
model=model,
|
model=model,
|
||||||
max_tokens=8192,
|
|
||||||
temperature=self.config.temperature,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track cost
|
||||||
|
if self.cost_tracker and response.usage:
|
||||||
|
await self.cost_tracker.record(model, response.usage)
|
||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
return response.content or "(no response)"
|
content = response.content or "(no response)"
|
||||||
|
# Cache the final response (only if no tool calls were used in the conversation)
|
||||||
|
if self.cache and iteration == 0:
|
||||||
|
await self.cache.put(model, messages, content)
|
||||||
|
return content
|
||||||
|
|
||||||
# Add assistant message with tool calls
|
# Add assistant message with tool calls
|
||||||
messages.append(self._assistant_message(response))
|
messages.append(self._assistant_message(response))
|
||||||
@@ -69,6 +137,8 @@ class Engine:
|
|||||||
# Execute each tool call
|
# Execute each tool call
|
||||||
for tc in response.tool_calls:
|
for tc in response.tool_calls:
|
||||||
result = await self._execute_tool(tc.name, tc.arguments)
|
result = await self._execute_tool(tc.name, tc.arguments)
|
||||||
|
# Scrub credentials from tool output
|
||||||
|
result = scrub_credentials(result)
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
@@ -83,8 +153,121 @@ class Engine:
|
|||||||
f"{len(response.tool_calls)} tool call(s)"
|
f"{len(response.tool_calls)} tool call(s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# History compaction — summarize old messages if history is too long
|
||||||
|
messages = await self._compact_history(messages, model)
|
||||||
|
|
||||||
return "(max iterations reached)"
|
return "(max iterations reached)"
|
||||||
|
|
||||||
|
async def _call_with_retry(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str,
) -> LLMResponse:
    """Call the LLM with retry and fallback.

    The primary provider is tried up to ``_MAX_RETRIES`` times, waiting
    ``_RETRY_BACKOFF_BASE ** attempt`` seconds between attempts. Errors whose
    message contains one of ``_NON_RETRYABLE_KEYWORDS`` abort the retries
    immediately. If the primary provider never succeeds, the fallback
    provider (when configured) is tried once with its own default model.

    Returns:
        The provider's response, or an ``LLMResponse`` with
        ``finish_reason="error"`` when every option failed. This method
        does not propagate provider exceptions.
    """
    last_error: Exception | None = None

    for attempt in range(_MAX_RETRIES):
        try:
            return await self.provider.complete(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=8192,
                temperature=self.config.temperature,
            )
        except Exception as e:
            last_error = e
            error_str = str(e).lower()

            # Don't retry non-retryable errors
            if any(kw in error_str for kw in _NON_RETRYABLE_KEYWORDS):
                logger.error(f"[{self.config.name}] Non-retryable error: {e}")
                break

            # No attempt follows the last one, so sleeping here would only
            # delay the fallback provider (or the error response).
            if attempt + 1 >= _MAX_RETRIES:
                logger.warning(
                    f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — no retries left"
                )
                break

            wait = _RETRY_BACKOFF_BASE ** attempt
            logger.warning(
                f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — retrying in {wait}s"
            )
            await asyncio.sleep(wait)

    # Try fallback provider
    if self.fallback_provider:
        logger.info(f"[{self.config.name}] Trying fallback provider")
        try:
            return await self.fallback_provider.complete(
                messages=messages,
                tools=tools,
                model=self.fallback_provider.get_default_model(),
                max_tokens=8192,
                temperature=self.config.temperature,
            )
        except Exception as fallback_err:
            logger.error(f"[{self.config.name}] Fallback also failed: {fallback_err}")

    # All retries and fallback exhausted
    return LLMResponse(
        content=f"(error after {_MAX_RETRIES} retries: {last_error})",
        finish_reason="error",
    )
|
||||||
|
|
||||||
|
async def _compact_history(
    self,
    messages: list[dict[str, Any]],
    model: str,
) -> list[dict[str, Any]]:
    """Compact old messages if history exceeds threshold.

    When more than ``_MAX_HISTORY_MESSAGES`` non-system messages have
    accumulated, everything except roughly the last ``_KEEP_RECENT`` of them
    is replaced by an LLM-written summary. System messages are kept and
    placed first.

    Returns:
        The compacted list, or ``messages`` unchanged when no compaction is
        needed, the summary call fails, or the summary comes back empty.
    """
    # Count non-system messages
    non_system = [m for m in messages if m["role"] != "system"]
    if len(non_system) <= _MAX_HISTORY_MESSAGES:
        return messages

    # Split: keep system messages, summarize old, keep recent.
    system_msgs = [m for m in messages if m["role"] == "system"]
    split = len(non_system) - _KEEP_RECENT
    # Never start the retained tail with a tool result: its assistant
    # tool-call message would be summarised away, leaving an orphaned "tool"
    # message. Move the boundary back until the pair stays together.
    while split > 0 and non_system[split]["role"] == "tool":
        split -= 1
    if split <= 0:
        return messages
    old_msgs = non_system[:split]
    recent_msgs = non_system[split:]

    # Ask LLM to summarize the old messages
    summary_messages = [
        {"role": "system", "content": _COMPACTION_PROMPT},
        {
            "role": "user",
            "content": json.dumps(
                [{"role": m["role"], "content": m.get("content", "")} for m in old_msgs],
                indent=None,
            ),
        },
    ]

    try:
        summary_response = await self.provider.complete(
            messages=summary_messages,
            model=model,
            max_tokens=2048,
            temperature=0.1,
        )
    except Exception as e:
        # Best effort: keep the full history rather than fail the request.
        logger.warning(f"[{self.config.name}] History compaction failed: {e}")
        return messages

    summary_text = summary_response.content or ""
    if not summary_text.strip():
        # Replacing the old messages with an empty summary would silently
        # throw the context away.
        logger.warning(f"[{self.config.name}] History compaction failed: empty summary")
        return messages

    logger.info(
        f"[{self.config.name}] Compacted {len(old_msgs)} messages → "
        f"{len(summary_text)} char summary"
    )

    # Rebuild: system + summary + recent
    compacted = list(system_msgs)
    compacted.append({
        "role": "user",
        "content": f"[Conversation summary]\n{summary_text}",
    })
    compacted.append({
        "role": "assistant",
        "content": "Understood, I have the context from our previous conversation.",
    })
    compacted.extend(recent_msgs)
    return compacted
|
||||||
|
|
||||||
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
|
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
|
||||||
"""Execute a tool with approval check."""
|
"""Execute a tool with approval check."""
|
||||||
approved = await self.approval.check(name, arguments)
|
approved = await self.approval.check(name, arguments)
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ from typing import Any
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from xtrm_agent.bus import AgentMessage, InboundMessage, MessageBus, OutboundMessage
|
from xtrm_agent.bus import AgentMessage, InboundMessage, MessageBus, OutboundMessage
|
||||||
|
from xtrm_agent.cache import ResponseCache
|
||||||
|
from xtrm_agent.classifier import QueryClassifier
|
||||||
from xtrm_agent.config import Config, AgentFileConfig, parse_agent_file
|
from xtrm_agent.config import Config, AgentFileConfig, parse_agent_file
|
||||||
|
from xtrm_agent.cost import BudgetConfig, CostTracker
|
||||||
from xtrm_agent.engine import Engine
|
from xtrm_agent.engine import Engine
|
||||||
from xtrm_agent.llm.anthropic import AnthropicProvider
|
from xtrm_agent.llm.anthropic import AnthropicProvider
|
||||||
from xtrm_agent.llm.litellm import LiteLLMProvider
|
from xtrm_agent.llm.litellm import LiteLLMProvider
|
||||||
@@ -35,6 +38,8 @@ class Orchestrator:
|
|||||||
self._agent_configs: dict[str, AgentFileConfig] = {}
|
self._agent_configs: dict[str, AgentFileConfig] = {}
|
||||||
self._mcp_stack = AsyncExitStack()
|
self._mcp_stack = AsyncExitStack()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._cache: ResponseCache | None = None
|
||||||
|
self._cost_tracker: CostTracker | None = None
|
||||||
|
|
||||||
# Channel defaults for routing
|
# Channel defaults for routing
|
||||||
channel_defaults = {}
|
channel_defaults = {}
|
||||||
@@ -53,6 +58,27 @@ class Orchestrator:
|
|||||||
workspace = Path(self.config.tools.workspace).resolve()
|
workspace = Path(self.config.tools.workspace).resolve()
|
||||||
workspace.mkdir(parents=True, exist_ok=True)
|
workspace.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize shared response cache
|
||||||
|
self._cache = ResponseCache(
|
||||||
|
db_path=workspace / "cache.db",
|
||||||
|
ttl=self.config.performance.cache_ttl,
|
||||||
|
)
|
||||||
|
await self._cache.setup()
|
||||||
|
|
||||||
|
# Initialize cost tracker
|
||||||
|
budget = BudgetConfig(
|
||||||
|
daily_limit_usd=self.config.performance.daily_budget_usd,
|
||||||
|
monthly_limit_usd=self.config.performance.monthly_budget_usd,
|
||||||
|
)
|
||||||
|
self._cost_tracker = CostTracker(
|
||||||
|
db_path=workspace / "costs.db",
|
||||||
|
budget=budget,
|
||||||
|
)
|
||||||
|
await self._cost_tracker.setup()
|
||||||
|
|
||||||
|
# Initialize query classifier
|
||||||
|
classifier = QueryClassifier(model_map=self.config.performance.model_routing)
|
||||||
|
|
||||||
# Parse all agent definitions
|
# Parse all agent definitions
|
||||||
for agent_name, agent_path in self.config.agents.items():
|
for agent_name, agent_path in self.config.agents.items():
|
||||||
p = Path(agent_path)
|
p = Path(agent_path)
|
||||||
@@ -74,6 +100,10 @@ class Orchestrator:
|
|||||||
await self._mcp_stack.__aenter__()
|
await self._mcp_stack.__aenter__()
|
||||||
await connect_mcp_servers(self.config.mcp_servers, global_registry, self._mcp_stack)
|
await connect_mcp_servers(self.config.mcp_servers, global_registry, self._mcp_stack)
|
||||||
|
|
||||||
|
# Create fallback provider (LiteLLM with a cheap model)
|
||||||
|
fallback_model = self.config.performance.fallback_model
|
||||||
|
fallback_provider = LiteLLMProvider(model=fallback_model) if fallback_model else None
|
||||||
|
|
||||||
# Create one engine per agent
|
# Create one engine per agent
|
||||||
agent_names = list(self._agent_configs.keys())
|
agent_names = list(self._agent_configs.keys())
|
||||||
for agent_name, agent_cfg in self._agent_configs.items():
|
for agent_name, agent_cfg in self._agent_configs.items():
|
||||||
@@ -107,6 +137,10 @@ class Orchestrator:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
tools=agent_registry,
|
tools=agent_registry,
|
||||||
approval=approval,
|
approval=approval,
|
||||||
|
cache=self._cache,
|
||||||
|
cost_tracker=self._cost_tracker,
|
||||||
|
classifier=classifier,
|
||||||
|
fallback_provider=fallback_provider,
|
||||||
)
|
)
|
||||||
self._engines[agent_name] = engine
|
self._engines[agent_name] = engine
|
||||||
|
|
||||||
@@ -190,6 +224,10 @@ class Orchestrator:
|
|||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
|
if self._cache:
|
||||||
|
await self._cache.close()
|
||||||
|
if self._cost_tracker:
|
||||||
|
await self._cost_tracker.close()
|
||||||
await self._mcp_stack.aclose()
|
await self._mcp_stack.aclose()
|
||||||
logger.info("Orchestrator stopped")
|
logger.info("Orchestrator stopped")
|
||||||
|
|
||||||
|
|||||||
46
xtrm_agent/scrub.py
Normal file
46
xtrm_agent/scrub.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Credential scrubbing — prevent secret leakage in tool output."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Patterns that match common secret formats
|
||||||
|
_SECRET_PATTERNS = [
    # Key=value patterns (API keys, tokens, passwords).
    # Group 1 is the key name, group 2 (the last group) is the secret value.
    re.compile(
        r"(?i)(api[_-]?key|token|password|passwd|secret|access[_-]?key|private[_-]?key|auth)"
        r"[\s]*[=:]\s*['\"]?([^\s'\"]{8,})['\"]?",
    ),
    # Bearer tokens. No capture group: the whole match is redacted, including
    # any trailing "=" base64 padding.
    re.compile(r"Bearer\s+[A-Za-z0-9\-._~+/]+=*", re.IGNORECASE),
    # AWS access keys (AKIA...)
    re.compile(r"AKIA[0-9A-Z]{16}"),
    # AWS secret keys (40 char base64); the only group is the secret value.
    re.compile(r"(?i)aws[_-]?secret[_-]?access[_-]?key[\s]*[=:]\s*['\"]?([A-Za-z0-9/+=]{40})['\"]?"),
    # GitHub tokens
    re.compile(r"gh[pousr]_[A-Za-z0-9_]{36,}"),
    # Generic long hex strings that look like secrets (32+ hex chars after key= or token=)
    re.compile(r"(?i)(?:key|token|secret)[=:]\s*['\"]?([0-9a-f]{32,})['\"]?"),
]

_REDACTED = "[REDACTED]"


def scrub_credentials(text: str) -> str:
    """Scrub potential secrets from text, replacing them with [REDACTED].

    Every pattern in ``_SECRET_PATTERNS`` is applied in order to the output of
    the previous one. For key/value matches the key and its separator stay
    visible (``api_key= [REDACTED]``); standalone secrets are replaced whole.

    Args:
        text: Arbitrary text (e.g. tool output) that may contain secrets.

    Returns:
        The text with every matched secret value replaced.
    """
    result = text
    for pattern in _SECRET_PATTERNS:
        result = pattern.sub(_redact_match, result)
    return result


def _redact_match(match: re.Match) -> str:
    """Build the replacement for one secret match.

    Patterns that capture the secret value in their last group are treated as
    key/value matches: the text up to and including the separator is kept and
    the value is redacted. Patterns without groups (Bearer, AKIA, gh*_) are
    redacted entirely.

    The separator is searched only in the text *before* the captured value.
    Searching the whole match would find "=" or ":" characters inside the
    secret itself (base64 padding, ``token: abc=def...``) and echo part of the
    secret back into the output.
    """
    if not match.lastindex:
        # Standalone pattern: nothing in the match is safe to keep.
        return _REDACTED

    # Text preceding the secret value, relative to the start of the match.
    value_offset = match.start(match.lastindex) - match.start()
    prefix = match.group(0)[:value_offset]

    for offset, char in enumerate(prefix):
        if char in "=:":
            return f"{prefix[: offset + 1]} {_REDACTED}"

    # Grouped pattern without a visible separator: redact everything.
    return _REDACTED
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ class ApprovalPolicy(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class ApprovalEngine:
|
class ApprovalEngine:
|
||||||
"""Deny-by-default tool approval."""
|
"""Deny-by-default tool approval with session allowlist and audit log."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -26,6 +27,8 @@ class ApprovalEngine:
|
|||||||
self._auto_approve = set(auto_approve or [])
|
self._auto_approve = set(auto_approve or [])
|
||||||
self._require_approval = set(require_approval or [])
|
self._require_approval = set(require_approval or [])
|
||||||
self._interactive = interactive
|
self._interactive = interactive
|
||||||
|
self._session_allowed: set[str] = set()
|
||||||
|
self._audit_log: list[dict[str, Any]] = []
|
||||||
|
|
||||||
def get_policy(self, tool_name: str) -> ApprovalPolicy:
|
def get_policy(self, tool_name: str) -> ApprovalPolicy:
|
||||||
"""Get the approval policy for a tool."""
|
"""Get the approval policy for a tool."""
|
||||||
@@ -43,27 +46,54 @@ class ApprovalEngine:
|
|||||||
policy = self.get_policy(tool_name)
|
policy = self.get_policy(tool_name)
|
||||||
|
|
||||||
if policy == ApprovalPolicy.AUTO_APPROVE:
|
if policy == ApprovalPolicy.AUTO_APPROVE:
|
||||||
|
self._log_decision(tool_name, arguments, "auto_approved")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Session-scoped "always allow"
|
||||||
|
if tool_name in self._session_allowed:
|
||||||
|
self._log_decision(tool_name, arguments, "session_allowed")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if policy == ApprovalPolicy.DENY:
|
if policy == ApprovalPolicy.DENY:
|
||||||
logger.warning(f"Tool '{tool_name}' denied by policy")
|
logger.warning(f"Tool '{tool_name}' denied by policy")
|
||||||
|
self._log_decision(tool_name, arguments, "denied")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# REQUIRE_APPROVAL
|
# REQUIRE_APPROVAL
|
||||||
if not self._interactive:
|
if not self._interactive:
|
||||||
logger.warning(f"Tool '{tool_name}' requires approval but running non-interactively — denied")
|
logger.warning(f"Tool '{tool_name}' requires approval but running non-interactively — denied")
|
||||||
|
self._log_decision(tool_name, arguments, "denied_non_interactive")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# In interactive mode, prompt the user
|
# In interactive mode, prompt the user
|
||||||
logger.info(f"Tool '{tool_name}' requires approval. Args: {arguments}")
|
logger.info(f"Tool '{tool_name}' requires approval. Args: {arguments}")
|
||||||
return await self._prompt_user(tool_name, arguments)
|
approved, always = await self._prompt_user(tool_name, arguments)
|
||||||
|
if approved and always:
|
||||||
|
self._session_allowed.add(tool_name)
|
||||||
|
self._log_decision(tool_name, arguments, "user_approved" if approved else "user_denied")
|
||||||
|
return approved
|
||||||
|
|
||||||
async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> bool:
|
async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> tuple[bool, bool]:
|
||||||
"""Prompt user for tool approval (interactive mode)."""
|
"""Prompt user for tool approval. Returns (approved, always_allow)."""
|
||||||
print(f"\n[APPROVAL REQUIRED] Tool: {tool_name}")
|
print(f"\n[APPROVAL REQUIRED] Tool: {tool_name}")
|
||||||
print(f" Arguments: {arguments}")
|
print(f" Arguments: {arguments}")
|
||||||
try:
|
try:
|
||||||
answer = input(" Allow? [y/N]: ").strip().lower()
|
answer = input(" Allow? [y/N/a(lways)]: ").strip().lower()
|
||||||
return answer in ("y", "yes")
|
if answer in ("a", "always"):
|
||||||
|
return True, True
|
||||||
|
return answer in ("y", "yes"), False
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
return False
|
return False, False
|
||||||
|
|
||||||
|
def _log_decision(self, tool_name: str, arguments: dict[str, Any], decision: str) -> None:
|
||||||
|
"""Record an approval decision in the audit log."""
|
||||||
|
self._audit_log.append({
|
||||||
|
"tool": tool_name,
|
||||||
|
"arguments": arguments,
|
||||||
|
"decision": decision,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_audit_log(self) -> list[dict[str, Any]]:
    """Return a copy of the audit log for inspection.

    Both the list and each entry dict are copied, so callers can filter or
    annotate the result without altering the recorded decisions. Entries are
    in the order the decisions were logged.
    """
    return [dict(entry) for entry in self._audit_log]
|
||||||
|
|||||||
Reference in New Issue
Block a user