Compare commits
2 Commits
b3608b35fa
...
872ed24f0c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
872ed24f0c | ||
|
|
e24e3026b6 |
@@ -10,6 +10,7 @@ tools:
|
|||||||
- edit_file
|
- edit_file
|
||||||
- list_dir
|
- list_dir
|
||||||
- bash
|
- bash
|
||||||
|
- web_search
|
||||||
- delegate
|
- delegate
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ model: nvidia_nim/deepseek-ai/deepseek-v3.1
|
|||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
max_iterations: 20
|
max_iterations: 20
|
||||||
tools:
|
tools:
|
||||||
|
- web_search
|
||||||
- web_fetch
|
- web_fetch
|
||||||
- read_file
|
- read_file
|
||||||
- list_dir
|
- list_dir
|
||||||
|
|||||||
10
config.yaml
10
config.yaml
@@ -27,6 +27,7 @@ tools:
|
|||||||
- read_file
|
- read_file
|
||||||
- list_dir
|
- list_dir
|
||||||
- web_fetch
|
- web_fetch
|
||||||
|
- web_search
|
||||||
- delegate
|
- delegate
|
||||||
- write_file
|
- write_file
|
||||||
- edit_file
|
- edit_file
|
||||||
@@ -42,3 +43,12 @@ agents:
|
|||||||
orchestrator:
|
orchestrator:
|
||||||
max_concurrent: 5
|
max_concurrent: 5
|
||||||
delegation_timeout: 120
|
delegation_timeout: 120
|
||||||
|
|
||||||
|
performance:
|
||||||
|
cache_ttl: 3600
|
||||||
|
daily_budget_usd: 5.0
|
||||||
|
monthly_budget_usd: 50.0
|
||||||
|
fallback_model: nvidia_nim/deepseek-ai/deepseek-v3.1
|
||||||
|
model_routing:
|
||||||
|
simple: nvidia_nim/deepseek-ai/deepseek-v3.1
|
||||||
|
research: nvidia_nim/moonshotai/kimi-k2.5
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ dependencies = [
|
|||||||
"httpx>=0.28.0",
|
"httpx>=0.28.0",
|
||||||
"loguru>=0.7.0",
|
"loguru>=0.7.0",
|
||||||
"json-repair>=0.30.0",
|
"json-repair>=0.30.0",
|
||||||
|
"duckduckgo-search>=7.0.0",
|
||||||
|
"pypdf>=5.0.0",
|
||||||
|
"aiosqlite>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -8,6 +8,14 @@ from datetime import datetime, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Attachment:
|
||||||
|
"""A text-extracted attachment from a user message."""
|
||||||
|
|
||||||
|
filename: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InboundMessage:
|
class InboundMessage:
|
||||||
"""Message from a channel (user) heading to an agent."""
|
"""Message from a channel (user) heading to an agent."""
|
||||||
@@ -19,6 +27,7 @@ class InboundMessage:
|
|||||||
target_agent: str | None = None
|
target_agent: str | None = None
|
||||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
attachments: list[Attachment] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
100
xtrm_agent/cache.py
Normal file
100
xtrm_agent/cache.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""Response cache — avoid redundant LLM calls for identical prompts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseCache:
|
||||||
|
"""SQLite-backed LLM response cache with TTL expiry."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path = "data/cache.db", ttl: int = 3600) -> None:
|
||||||
|
self._db_path = str(db_path)
|
||||||
|
self._ttl = ttl
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
async def setup(self) -> None:
|
||||||
|
"""Create the cache table if it doesn't exist."""
|
||||||
|
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS response_cache (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
response TEXT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
created_at REAL NOT NULL,
|
||||||
|
hits INTEGER DEFAULT 0
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._db:
|
||||||
|
await self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_key(model: str, messages: list[dict]) -> str:
|
||||||
|
"""SHA-256 hash of model + messages for cache key."""
|
||||||
|
raw = json.dumps({"model": model, "messages": messages}, sort_keys=True)
|
||||||
|
return hashlib.sha256(raw.encode()).hexdigest()
|
||||||
|
|
||||||
|
async def get(self, model: str, messages: list[dict]) -> str | None:
|
||||||
|
"""Look up a cached response. Returns None on miss or expired."""
|
||||||
|
if not self._db:
|
||||||
|
return None
|
||||||
|
key = self._make_key(model, messages)
|
||||||
|
now = time.time()
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT response, created_at FROM response_cache WHERE key = ?",
|
||||||
|
(key,),
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
response, created_at = row
|
||||||
|
if now - created_at > self._ttl:
|
||||||
|
await self._db.execute("DELETE FROM response_cache WHERE key = ?", (key,))
|
||||||
|
await self._db.commit()
|
||||||
|
return None
|
||||||
|
# Bump hit count
|
||||||
|
await self._db.execute(
|
||||||
|
"UPDATE response_cache SET hits = hits + 1 WHERE key = ?", (key,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
logger.debug(f"Cache hit for {model} (key={key[:12]}...)")
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def put(self, model: str, messages: list[dict], response: str) -> None:
|
||||||
|
"""Store a response in the cache."""
|
||||||
|
if not self._db:
|
||||||
|
return
|
||||||
|
key = self._make_key(model, messages)
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO response_cache (key, response, model, created_at, hits)
|
||||||
|
VALUES (?, ?, ?, ?, 0)
|
||||||
|
""",
|
||||||
|
(key, response, model, time.time()),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def clear_expired(self) -> int:
|
||||||
|
"""Remove expired entries. Returns count of deleted rows."""
|
||||||
|
if not self._db:
|
||||||
|
return 0
|
||||||
|
cutoff = time.time() - self._ttl
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"DELETE FROM response_cache WHERE created_at < ?", (cutoff,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
return cursor.rowcount
|
||||||
@@ -3,14 +3,29 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from xtrm_agent.bus import InboundMessage, MessageBus, OutboundMessage
|
from xtrm_agent.bus import Attachment, InboundMessage, MessageBus, OutboundMessage
|
||||||
from xtrm_agent.channels.base import BaseChannel
|
from xtrm_agent.channels.base import BaseChannel
|
||||||
|
|
||||||
|
# Extensions treated as plain text (decoded as UTF-8)
|
||||||
|
_TEXT_EXTENSIONS = frozenset({
|
||||||
|
".txt", ".py", ".md", ".json", ".yaml", ".yml", ".csv", ".log",
|
||||||
|
".js", ".ts", ".html", ".css", ".xml", ".toml", ".ini", ".sh",
|
||||||
|
".sql", ".rs", ".go", ".java", ".c", ".cpp", ".h", ".rb", ".php",
|
||||||
|
".swift", ".kt", ".r", ".cfg", ".env", ".conf", ".dockerfile",
|
||||||
|
".makefile", ".bat", ".ps1", ".lua", ".zig", ".hs",
|
||||||
|
})
|
||||||
|
|
||||||
|
_IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"})
|
||||||
|
|
||||||
|
_MAX_ATTACHMENT_SIZE = 1_024_000 # 1 MB
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannel(BaseChannel):
|
class DiscordChannel(BaseChannel):
|
||||||
"""Discord bot channel."""
|
"""Discord bot channel."""
|
||||||
@@ -54,12 +69,15 @@ class DiscordChannel(BaseChannel):
|
|||||||
if self.client.user:
|
if self.client.user:
|
||||||
content = content.replace(f"<@{self.client.user.id}>", "").strip()
|
content = content.replace(f"<@{self.client.user.id}>", "").strip()
|
||||||
|
|
||||||
|
attachments = await self._extract_attachments(message.attachments)
|
||||||
|
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel="discord",
|
channel="discord",
|
||||||
sender_id=str(message.author.id),
|
sender_id=str(message.author.id),
|
||||||
chat_id=str(message.channel.id),
|
chat_id=str(message.channel.id),
|
||||||
content=content,
|
content=content,
|
||||||
metadata={"guild_id": str(message.guild.id) if message.guild else ""},
|
metadata={"guild_id": str(message.guild.id) if message.guild else ""},
|
||||||
|
attachments=attachments,
|
||||||
)
|
)
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
|
|
||||||
@@ -71,6 +89,80 @@ class DiscordChannel(BaseChannel):
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
await message.channel.send("Sorry, I timed out processing your request.")
|
await message.channel.send("Sorry, I timed out processing your request.")
|
||||||
|
|
||||||
|
async def _extract_attachments(
|
||||||
|
self, discord_attachments: list[discord.Attachment]
|
||||||
|
) -> list[Attachment]:
|
||||||
|
"""Download Discord attachments and extract text content."""
|
||||||
|
results: list[Attachment] = []
|
||||||
|
for att in discord_attachments:
|
||||||
|
name = att.filename.lower()
|
||||||
|
ext = "." + name.rsplit(".", 1)[-1] if "." in name else ""
|
||||||
|
|
||||||
|
if att.size > _MAX_ATTACHMENT_SIZE:
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content=f"(file skipped — {att.size / 1_048_576:.1f} MB exceeds 1 MB limit)",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ext in _IMAGE_EXTENSIONS:
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content="(image attached — cannot read image content)",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
resp = await client.get(att.url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
raw = resp.content
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to download attachment {att.filename}: {e}")
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content=f"(failed to download: {e})",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ext == ".pdf":
|
||||||
|
try:
|
||||||
|
from pypdf import PdfReader
|
||||||
|
|
||||||
|
reader = PdfReader(io.BytesIO(raw))
|
||||||
|
text = "\n".join(
|
||||||
|
page.extract_text() or "" for page in reader.pages
|
||||||
|
).strip()
|
||||||
|
if text:
|
||||||
|
results.append(Attachment(filename=att.filename, content=text))
|
||||||
|
else:
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content="(PDF has no extractable text)",
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to extract PDF text from {att.filename}: {e}")
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content=f"(failed to read PDF: {e})",
|
||||||
|
))
|
||||||
|
elif ext in _TEXT_EXTENSIONS or (att.content_type and att.content_type.startswith("text/")):
|
||||||
|
try:
|
||||||
|
text = raw.decode("utf-8", errors="replace")
|
||||||
|
results.append(Attachment(filename=att.filename, content=text))
|
||||||
|
except Exception as e:
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content=f"(failed to decode text: {e})",
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
results.append(Attachment(
|
||||||
|
filename=att.filename,
|
||||||
|
content=f"(unsupported file type: {ext or 'unknown'})",
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
async def _send_chunked(
|
async def _send_chunked(
|
||||||
self, channel: discord.abc.Messageable, content: str
|
self, channel: discord.abc.Messageable, content: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
86
xtrm_agent/classifier.py
Normal file
86
xtrm_agent/classifier.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Query classifier — route queries to appropriate models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassificationResult:
|
||||||
|
"""Result of classifying a query."""
|
||||||
|
|
||||||
|
category: str # "simple", "code", "research", "complex"
|
||||||
|
model_hint: str # suggested model override, or "" to use default
|
||||||
|
|
||||||
|
|
||||||
|
# Short, trivial queries — greetings, thanks, yes/no, etc.
|
||||||
|
_SIMPLE_PATTERNS = [
|
||||||
|
re.compile(r"^(hi|hello|hey|thanks|thank you|ok|yes|no|sure|bye|good\s*(morning|evening|night))[\s!.?]*$", re.IGNORECASE),
|
||||||
|
re.compile(r"^(what time|what date|what day)[\s\S]{0,20}\??\s*$", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Code-related queries
|
||||||
|
_CODE_PATTERNS = [
|
||||||
|
re.compile(r"\b(function|class|def|import|const|let|var|return|async|await)\b"),
|
||||||
|
re.compile(r"\b(refactor|debug|fix|implement|code|compile|build|test|lint|deploy)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"```"),
|
||||||
|
re.compile(r"\.(py|js|ts|rs|go|java|cpp|c|rb|php|swift|kt)\b"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Research / information queries
|
||||||
|
_RESEARCH_PATTERNS = [
|
||||||
|
re.compile(r"\b(search|find|look\s*up|research|summarize|explain|compare|analyze)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\b(what is|how does|why does|can you explain)\b", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class QueryClassifier:
|
||||||
|
"""Classify queries for model routing.
|
||||||
|
|
||||||
|
Configure with a mapping of category → model override.
|
||||||
|
Categories: "simple", "code", "research", "complex".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_map: dict[str, str] | None = None) -> None:
|
||||||
|
self._model_map = model_map or {}
|
||||||
|
|
||||||
|
def classify(self, query: str) -> ClassificationResult:
|
||||||
|
"""Classify a query and suggest a model."""
|
||||||
|
query_stripped = query.strip()
|
||||||
|
|
||||||
|
# Very short queries are likely simple
|
||||||
|
if len(query_stripped) < 20:
|
||||||
|
for pattern in _SIMPLE_PATTERNS:
|
||||||
|
if pattern.match(query_stripped):
|
||||||
|
return ClassificationResult(
|
||||||
|
category="simple",
|
||||||
|
model_hint=self._model_map.get("simple", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Code patterns
|
||||||
|
code_score = sum(1 for p in _CODE_PATTERNS if p.search(query_stripped))
|
||||||
|
if code_score >= 2:
|
||||||
|
return ClassificationResult(
|
||||||
|
category="code",
|
||||||
|
model_hint=self._model_map.get("code", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Research patterns
|
||||||
|
research_score = sum(1 for p in _RESEARCH_PATTERNS if p.search(query_stripped))
|
||||||
|
if research_score >= 1:
|
||||||
|
return ClassificationResult(
|
||||||
|
category="research",
|
||||||
|
model_hint=self._model_map.get("research", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Long or complex queries
|
||||||
|
if len(query_stripped) > 500 or query_stripped.count("\n") > 5:
|
||||||
|
return ClassificationResult(
|
||||||
|
category="complex",
|
||||||
|
model_hint=self._model_map.get("complex", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default — no override
|
||||||
|
return ClassificationResult(category="complex", model_hint="")
|
||||||
@@ -61,6 +61,16 @@ class MCPServerConfig(BaseModel):
|
|||||||
url: str = ""
|
url: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PerformanceConfig(BaseModel):
|
||||||
|
"""Performance tuning — caching, cost tracking, model routing."""
|
||||||
|
|
||||||
|
cache_ttl: int = 3600
|
||||||
|
daily_budget_usd: float = 0.0
|
||||||
|
monthly_budget_usd: float = 0.0
|
||||||
|
fallback_model: str = ""
|
||||||
|
model_routing: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class OrchestratorConfig(BaseModel):
|
class OrchestratorConfig(BaseModel):
|
||||||
max_concurrent: int = 5
|
max_concurrent: int = 5
|
||||||
delegation_timeout: int = 120
|
delegation_timeout: int = 120
|
||||||
@@ -87,6 +97,7 @@ class Config(BaseModel):
|
|||||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||||
agents: dict[str, str] = Field(default_factory=dict)
|
agents: dict[str, str] = Field(default_factory=dict)
|
||||||
orchestrator: OrchestratorConfig = Field(default_factory=OrchestratorConfig)
|
orchestrator: OrchestratorConfig = Field(default_factory=OrchestratorConfig)
|
||||||
|
performance: PerformanceConfig = Field(default_factory=PerformanceConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_config(path: str | Path = "config.yaml") -> Config:
|
def load_config(path: str | Path = "config.yaml") -> Config:
|
||||||
|
|||||||
173
xtrm_agent/cost.py
Normal file
173
xtrm_agent/cost.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""Cost tracking — monitor token usage and enforce budget limits."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
# Pricing per 1M tokens (input, output) in USD
|
||||||
|
MODEL_PRICING: dict[str, tuple[float, float]] = {
|
||||||
|
# Anthropic
|
||||||
|
"claude-sonnet-4-5-20250929": (3.0, 15.0),
|
||||||
|
"claude-haiku-4-5-20251001": (0.80, 4.0),
|
||||||
|
"claude-opus-4-6": (15.0, 75.0),
|
||||||
|
# NVIDIA NIM / LiteLLM
|
||||||
|
"nvidia_nim/deepseek-ai/deepseek-v3.1": (0.27, 1.10),
|
||||||
|
"nvidia_nim/moonshotai/kimi-k2.5": (0.35, 1.40),
|
||||||
|
"nvidia_nim/minimaxai/minimax-m2.1": (0.30, 1.10),
|
||||||
|
# Fallback
|
||||||
|
"_default": (1.0, 3.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pricing(model: str) -> tuple[float, float]:
|
||||||
|
"""Get (input_price, output_price) per 1M tokens for a model."""
|
||||||
|
return MODEL_PRICING.get(model, MODEL_PRICING["_default"])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CostRecord:
|
||||||
|
model: str
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
cost_usd: float
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BudgetConfig:
|
||||||
|
daily_limit_usd: float = 0.0 # 0 = no limit
|
||||||
|
monthly_limit_usd: float = 0.0 # 0 = no limit
|
||||||
|
warning_threshold: float = 0.8 # warn at 80% of budget
|
||||||
|
|
||||||
|
|
||||||
|
class CostTracker:
|
||||||
|
"""Track LLM costs in SQLite and enforce budget limits."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str | Path = "data/costs.db",
|
||||||
|
budget: BudgetConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._db_path = str(db_path)
|
||||||
|
self._budget = budget or BudgetConfig()
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
async def setup(self) -> None:
|
||||||
|
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await self._db.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS cost_records (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
input_tokens INTEGER NOT NULL,
|
||||||
|
output_tokens INTEGER NOT NULL,
|
||||||
|
cost_usd REAL NOT NULL,
|
||||||
|
timestamp REAL NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._db:
|
||||||
|
await self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
async def record(self, model: str, usage: dict[str, int]) -> CostRecord:
|
||||||
|
"""Record token usage and return the cost record."""
|
||||||
|
input_tokens = usage.get("input_tokens", 0)
|
||||||
|
output_tokens = usage.get("output_tokens", 0)
|
||||||
|
input_price, output_price = _get_pricing(model)
|
||||||
|
cost = (input_tokens * input_price + output_tokens * output_price) / 1_000_000
|
||||||
|
|
||||||
|
record = CostRecord(
|
||||||
|
model=model,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
cost_usd=cost,
|
||||||
|
timestamp=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._db:
|
||||||
|
await self._db.execute(
|
||||||
|
"INSERT INTO cost_records (model, input_tokens, output_tokens, cost_usd, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(model, input_tokens, output_tokens, cost, record.timestamp),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Cost: {model} — {input_tokens}in/{output_tokens}out = ${cost:.6f}"
|
||||||
|
)
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def check_budget(self) -> str:
|
||||||
|
"""Check budget status. Returns 'ok', 'warning', or 'exceeded'."""
|
||||||
|
if not self._db:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Daily check
|
||||||
|
if self._budget.daily_limit_usd > 0:
|
||||||
|
day_start = now - (now % 86400)
|
||||||
|
daily_total = await self._sum_costs_since(day_start)
|
||||||
|
if daily_total >= self._budget.daily_limit_usd:
|
||||||
|
logger.warning(f"Daily budget EXCEEDED: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}")
|
||||||
|
return "exceeded"
|
||||||
|
if daily_total >= self._budget.daily_limit_usd * self._budget.warning_threshold:
|
||||||
|
logger.warning(f"Daily budget warning: ${daily_total:.4f} / ${self._budget.daily_limit_usd:.2f}")
|
||||||
|
return "warning"
|
||||||
|
|
||||||
|
# Monthly check
|
||||||
|
if self._budget.monthly_limit_usd > 0:
|
||||||
|
month_start = now - (now % (86400 * 30))
|
||||||
|
monthly_total = await self._sum_costs_since(month_start)
|
||||||
|
if monthly_total >= self._budget.monthly_limit_usd:
|
||||||
|
logger.warning(f"Monthly budget EXCEEDED: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}")
|
||||||
|
return "exceeded"
|
||||||
|
if monthly_total >= self._budget.monthly_limit_usd * self._budget.warning_threshold:
|
||||||
|
logger.warning(f"Monthly budget warning: ${monthly_total:.4f} / ${self._budget.monthly_limit_usd:.2f}")
|
||||||
|
return "warning"
|
||||||
|
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
async def get_summary(self) -> dict:
|
||||||
|
"""Get a cost summary with per-model breakdown."""
|
||||||
|
if not self._db:
|
||||||
|
return {"total_usd": 0, "models": {}}
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
day_start = now - (now % 86400)
|
||||||
|
|
||||||
|
rows: list = []
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT model, SUM(input_tokens), SUM(output_tokens), SUM(cost_usd) FROM cost_records WHERE timestamp >= ? GROUP BY model",
|
||||||
|
(day_start,),
|
||||||
|
) as cursor:
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
models = {}
|
||||||
|
total = 0.0
|
||||||
|
for model, inp, out, cost in rows:
|
||||||
|
models[model] = {"input_tokens": inp, "output_tokens": out, "cost_usd": cost}
|
||||||
|
total += cost
|
||||||
|
|
||||||
|
return {"total_usd": total, "models": models}
|
||||||
|
|
||||||
|
async def _sum_costs_since(self, since: float) -> float:
|
||||||
|
if not self._db:
|
||||||
|
return 0.0
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT COALESCE(SUM(cost_usd), 0) FROM cost_records WHERE timestamp >= ?",
|
||||||
|
(since,),
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return row[0] if row else 0.0
|
||||||
@@ -2,17 +2,45 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from xtrm_agent.cache import ResponseCache
|
||||||
|
from xtrm_agent.classifier import ClassificationResult, QueryClassifier
|
||||||
from xtrm_agent.config import AgentFileConfig
|
from xtrm_agent.config import AgentFileConfig
|
||||||
import json
|
from xtrm_agent.cost import CostTracker
|
||||||
|
|
||||||
from xtrm_agent.llm.provider import LLMProvider, LLMResponse
|
from xtrm_agent.llm.provider import LLMProvider, LLMResponse
|
||||||
|
from xtrm_agent.scrub import scrub_credentials
|
||||||
from xtrm_agent.tools.approval import ApprovalEngine
|
from xtrm_agent.tools.approval import ApprovalEngine
|
||||||
from xtrm_agent.tools.registry import ToolRegistry
|
from xtrm_agent.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
# History compaction settings
|
||||||
|
_MAX_HISTORY_MESSAGES = 50
|
||||||
|
_KEEP_RECENT = 20
|
||||||
|
_COMPACTION_PROMPT = (
|
||||||
|
"Summarize the following conversation history concisely. "
|
||||||
|
"Preserve key facts, decisions, tool results, and context needed to continue. "
|
||||||
|
"Be brief but complete."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry settings
|
||||||
|
_MAX_RETRIES = 3
|
||||||
|
_RETRY_BACKOFF_BASE = 2.0
|
||||||
|
_NON_RETRYABLE_KEYWORDS = [
|
||||||
|
"authentication",
|
||||||
|
"unauthorized",
|
||||||
|
"invalid api key",
|
||||||
|
"model not found",
|
||||||
|
"permission denied",
|
||||||
|
"invalid_api_key",
|
||||||
|
"401",
|
||||||
|
"403",
|
||||||
|
"404",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
"""Runs one agent's LLM loop: messages → LLM → tool calls → loop → response."""
|
"""Runs one agent's LLM loop: messages → LLM → tool calls → loop → response."""
|
||||||
@@ -23,16 +51,33 @@ class Engine:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
tools: ToolRegistry,
|
tools: ToolRegistry,
|
||||||
approval: ApprovalEngine,
|
approval: ApprovalEngine,
|
||||||
|
cache: ResponseCache | None = None,
|
||||||
|
cost_tracker: CostTracker | None = None,
|
||||||
|
classifier: QueryClassifier | None = None,
|
||||||
|
fallback_provider: LLMProvider | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config = agent_config
|
self.config = agent_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.approval = approval
|
self.approval = approval
|
||||||
|
self.cache = cache
|
||||||
|
self.cost_tracker = cost_tracker
|
||||||
|
self.classifier = classifier
|
||||||
|
self.fallback_provider = fallback_provider
|
||||||
|
|
||||||
async def run(self, user_message: str) -> str:
|
async def run(self, user_message: str) -> str:
|
||||||
"""Process a single user message through the agent loop."""
|
"""Process a single user message through the agent loop."""
|
||||||
messages = self._build_initial_messages(user_message)
|
messages = self._build_initial_messages(user_message)
|
||||||
return await self._agent_loop(messages)
|
|
||||||
|
# Query classification — override model if classifier suggests one
|
||||||
|
model_override = ""
|
||||||
|
if self.classifier:
|
||||||
|
result = self.classifier.classify(user_message)
|
||||||
|
if result.model_hint:
|
||||||
|
model_override = result.model_hint
|
||||||
|
logger.debug(f"[{self.config.name}] Classified as '{result.category}' → model={result.model_hint}")
|
||||||
|
|
||||||
|
return await self._agent_loop(messages, model_override=model_override)
|
||||||
|
|
||||||
async def run_delegation(self, task: str) -> str:
|
async def run_delegation(self, task: str) -> str:
|
||||||
"""Process a delegation task (no system prompt changes)."""
|
"""Process a delegation task (no system prompt changes)."""
|
||||||
@@ -46,22 +91,45 @@ class Engine:
|
|||||||
messages.append({"role": "user", "content": user_message})
|
messages.append({"role": "user", "content": user_message})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def _agent_loop(self, messages: list[dict[str, Any]]) -> str:
|
async def _agent_loop(
|
||||||
"""Core agent iteration loop."""
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
model_override: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Core agent iteration loop with caching, cost tracking, retry, and compaction."""
|
||||||
for iteration in range(self.config.max_iterations):
|
for iteration in range(self.config.max_iterations):
|
||||||
model = self.config.model or self.provider.get_default_model()
|
model = model_override or self.config.model or self.provider.get_default_model()
|
||||||
tool_defs = self.tools.get_definitions() if self.tools.names() else None
|
tool_defs = self.tools.get_definitions() if self.tools.names() else None
|
||||||
|
|
||||||
response = await self.provider.complete(
|
# Budget check
|
||||||
|
if self.cost_tracker:
|
||||||
|
status = await self.cost_tracker.check_budget()
|
||||||
|
if status == "exceeded":
|
||||||
|
return "(budget exceeded — please try again later)"
|
||||||
|
|
||||||
|
# Response cache check (only for first iteration — no tool calls yet)
|
||||||
|
if self.cache and iteration == 0 and not tool_defs:
|
||||||
|
cached = await self.cache.get(model, messages)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# LLM call with retry + fallback
|
||||||
|
response = await self._call_with_retry(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tool_defs,
|
tools=tool_defs,
|
||||||
model=model,
|
model=model,
|
||||||
max_tokens=8192,
|
|
||||||
temperature=self.config.temperature,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track cost
|
||||||
|
if self.cost_tracker and response.usage:
|
||||||
|
await self.cost_tracker.record(model, response.usage)
|
||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
return response.content or "(no response)"
|
content = response.content or "(no response)"
|
||||||
|
# Cache the final response (only if no tool calls were used in the conversation)
|
||||||
|
if self.cache and iteration == 0:
|
||||||
|
await self.cache.put(model, messages, content)
|
||||||
|
return content
|
||||||
|
|
||||||
# Add assistant message with tool calls
|
# Add assistant message with tool calls
|
||||||
messages.append(self._assistant_message(response))
|
messages.append(self._assistant_message(response))
|
||||||
@@ -69,6 +137,8 @@ class Engine:
|
|||||||
# Execute each tool call
|
# Execute each tool call
|
||||||
for tc in response.tool_calls:
|
for tc in response.tool_calls:
|
||||||
result = await self._execute_tool(tc.name, tc.arguments)
|
result = await self._execute_tool(tc.name, tc.arguments)
|
||||||
|
# Scrub credentials from tool output
|
||||||
|
result = scrub_credentials(result)
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
@@ -83,8 +153,121 @@ class Engine:
|
|||||||
f"{len(response.tool_calls)} tool call(s)"
|
f"{len(response.tool_calls)} tool call(s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# History compaction — summarize old messages if history is too long
|
||||||
|
messages = await self._compact_history(messages, model)
|
||||||
|
|
||||||
return "(max iterations reached)"
|
return "(max iterations reached)"
|
||||||
|
|
||||||
|
async def _call_with_retry(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
model: str,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Call the LLM with retry and fallback."""
|
||||||
|
last_error: Exception | None = None
|
||||||
|
|
||||||
|
for attempt in range(_MAX_RETRIES):
|
||||||
|
try:
|
||||||
|
return await self.provider.complete(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=model,
|
||||||
|
max_tokens=8192,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
error_str = str(e).lower()
|
||||||
|
|
||||||
|
# Don't retry non-retryable errors
|
||||||
|
if any(kw in error_str for kw in _NON_RETRYABLE_KEYWORDS):
|
||||||
|
logger.error(f"[{self.config.name}] Non-retryable error: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
wait = _RETRY_BACKOFF_BASE ** attempt
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — retrying in {wait}s"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
|
||||||
|
# Try fallback provider
|
||||||
|
if self.fallback_provider:
|
||||||
|
logger.info(f"[{self.config.name}] Trying fallback provider")
|
||||||
|
try:
|
||||||
|
return await self.fallback_provider.complete(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=self.fallback_provider.get_default_model(),
|
||||||
|
max_tokens=8192,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
)
|
||||||
|
except Exception as fallback_err:
|
||||||
|
logger.error(f"[{self.config.name}] Fallback also failed: {fallback_err}")
|
||||||
|
|
||||||
|
# All retries and fallback exhausted
|
||||||
|
return LLMResponse(
|
||||||
|
content=f"(error after {_MAX_RETRIES} retries: {last_error})",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _compact_history(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
model: str,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Compact old messages if history exceeds threshold."""
|
||||||
|
# Count non-system messages
|
||||||
|
non_system = [m for m in messages if m["role"] != "system"]
|
||||||
|
if len(non_system) <= _MAX_HISTORY_MESSAGES:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# Split: keep system messages, summarize old, keep recent
|
||||||
|
system_msgs = [m for m in messages if m["role"] == "system"]
|
||||||
|
old_msgs = non_system[:-_KEEP_RECENT]
|
||||||
|
recent_msgs = non_system[-_KEEP_RECENT:]
|
||||||
|
|
||||||
|
# Ask LLM to summarize the old messages
|
||||||
|
summary_messages = [
|
||||||
|
{"role": "system", "content": _COMPACTION_PROMPT},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": json.dumps(
|
||||||
|
[{"role": m["role"], "content": m.get("content", "")} for m in old_msgs],
|
||||||
|
indent=None,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary_response = await self.provider.complete(
|
||||||
|
messages=summary_messages,
|
||||||
|
model=model,
|
||||||
|
max_tokens=2048,
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
summary_text = summary_response.content or ""
|
||||||
|
logger.info(
|
||||||
|
f"[{self.config.name}] Compacted {len(old_msgs)} messages → "
|
||||||
|
f"{len(summary_text)} char summary"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{self.config.name}] History compaction failed: {e}")
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# Rebuild: system + summary + recent
|
||||||
|
compacted = list(system_msgs)
|
||||||
|
compacted.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": f"[Conversation summary]\n{summary_text}",
|
||||||
|
})
|
||||||
|
compacted.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Understood, I have the context from our previous conversation.",
|
||||||
|
})
|
||||||
|
compacted.extend(recent_msgs)
|
||||||
|
return compacted
|
||||||
|
|
||||||
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
|
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
|
||||||
"""Execute a tool with approval check."""
|
"""Execute a tool with approval check."""
|
||||||
approved = await self.approval.check(name, arguments)
|
approved = await self.approval.check(name, arguments)
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ from typing import Any
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from xtrm_agent.bus import AgentMessage, InboundMessage, MessageBus, OutboundMessage
|
from xtrm_agent.bus import AgentMessage, InboundMessage, MessageBus, OutboundMessage
|
||||||
|
from xtrm_agent.cache import ResponseCache
|
||||||
|
from xtrm_agent.classifier import QueryClassifier
|
||||||
from xtrm_agent.config import Config, AgentFileConfig, parse_agent_file
|
from xtrm_agent.config import Config, AgentFileConfig, parse_agent_file
|
||||||
|
from xtrm_agent.cost import BudgetConfig, CostTracker
|
||||||
from xtrm_agent.engine import Engine
|
from xtrm_agent.engine import Engine
|
||||||
from xtrm_agent.llm.anthropic import AnthropicProvider
|
from xtrm_agent.llm.anthropic import AnthropicProvider
|
||||||
from xtrm_agent.llm.litellm import LiteLLMProvider
|
from xtrm_agent.llm.litellm import LiteLLMProvider
|
||||||
@@ -35,6 +38,8 @@ class Orchestrator:
|
|||||||
self._agent_configs: dict[str, AgentFileConfig] = {}
|
self._agent_configs: dict[str, AgentFileConfig] = {}
|
||||||
self._mcp_stack = AsyncExitStack()
|
self._mcp_stack = AsyncExitStack()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._cache: ResponseCache | None = None
|
||||||
|
self._cost_tracker: CostTracker | None = None
|
||||||
|
|
||||||
# Channel defaults for routing
|
# Channel defaults for routing
|
||||||
channel_defaults = {}
|
channel_defaults = {}
|
||||||
@@ -53,6 +58,27 @@ class Orchestrator:
|
|||||||
workspace = Path(self.config.tools.workspace).resolve()
|
workspace = Path(self.config.tools.workspace).resolve()
|
||||||
workspace.mkdir(parents=True, exist_ok=True)
|
workspace.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize shared response cache
|
||||||
|
self._cache = ResponseCache(
|
||||||
|
db_path=workspace / "cache.db",
|
||||||
|
ttl=self.config.performance.cache_ttl,
|
||||||
|
)
|
||||||
|
await self._cache.setup()
|
||||||
|
|
||||||
|
# Initialize cost tracker
|
||||||
|
budget = BudgetConfig(
|
||||||
|
daily_limit_usd=self.config.performance.daily_budget_usd,
|
||||||
|
monthly_limit_usd=self.config.performance.monthly_budget_usd,
|
||||||
|
)
|
||||||
|
self._cost_tracker = CostTracker(
|
||||||
|
db_path=workspace / "costs.db",
|
||||||
|
budget=budget,
|
||||||
|
)
|
||||||
|
await self._cost_tracker.setup()
|
||||||
|
|
||||||
|
# Initialize query classifier
|
||||||
|
classifier = QueryClassifier(model_map=self.config.performance.model_routing)
|
||||||
|
|
||||||
# Parse all agent definitions
|
# Parse all agent definitions
|
||||||
for agent_name, agent_path in self.config.agents.items():
|
for agent_name, agent_path in self.config.agents.items():
|
||||||
p = Path(agent_path)
|
p = Path(agent_path)
|
||||||
@@ -74,6 +100,10 @@ class Orchestrator:
|
|||||||
await self._mcp_stack.__aenter__()
|
await self._mcp_stack.__aenter__()
|
||||||
await connect_mcp_servers(self.config.mcp_servers, global_registry, self._mcp_stack)
|
await connect_mcp_servers(self.config.mcp_servers, global_registry, self._mcp_stack)
|
||||||
|
|
||||||
|
# Create fallback provider (LiteLLM with a cheap model)
|
||||||
|
fallback_model = self.config.performance.fallback_model
|
||||||
|
fallback_provider = LiteLLMProvider(model=fallback_model) if fallback_model else None
|
||||||
|
|
||||||
# Create one engine per agent
|
# Create one engine per agent
|
||||||
agent_names = list(self._agent_configs.keys())
|
agent_names = list(self._agent_configs.keys())
|
||||||
for agent_name, agent_cfg in self._agent_configs.items():
|
for agent_name, agent_cfg in self._agent_configs.items():
|
||||||
@@ -107,6 +137,10 @@ class Orchestrator:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
tools=agent_registry,
|
tools=agent_registry,
|
||||||
approval=approval,
|
approval=approval,
|
||||||
|
cache=self._cache,
|
||||||
|
cost_tracker=self._cost_tracker,
|
||||||
|
classifier=classifier,
|
||||||
|
fallback_provider=fallback_provider,
|
||||||
)
|
)
|
||||||
self._engines[agent_name] = engine
|
self._engines[agent_name] = engine
|
||||||
|
|
||||||
@@ -136,6 +170,15 @@ class Orchestrator:
|
|||||||
return f"Error: Agent '{agent_name}' not found"
|
return f"Error: Agent '{agent_name}' not found"
|
||||||
|
|
||||||
content = self.router.strip_mention(msg.content) if msg.content.startswith("@") else msg.content
|
content = self.router.strip_mention(msg.content) if msg.content.startswith("@") else msg.content
|
||||||
|
|
||||||
|
# Prepend attachment content so the LLM can see it
|
||||||
|
if msg.attachments:
|
||||||
|
parts: list[str] = []
|
||||||
|
for att in msg.attachments:
|
||||||
|
parts.append(f"[Attached file: {att.filename}]\n{att.content}")
|
||||||
|
parts.append(content)
|
||||||
|
content = "\n\n".join(parts)
|
||||||
|
|
||||||
logger.info(f"[{agent_name}] Processing: {content[:80]}")
|
logger.info(f"[{agent_name}] Processing: {content[:80]}")
|
||||||
return await engine.run(content)
|
return await engine.run(content)
|
||||||
|
|
||||||
@@ -181,6 +224,10 @@ class Orchestrator:
|
|||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
|
if self._cache:
|
||||||
|
await self._cache.close()
|
||||||
|
if self._cost_tracker:
|
||||||
|
await self._cost_tracker.close()
|
||||||
await self._mcp_stack.aclose()
|
await self._mcp_stack.aclose()
|
||||||
logger.info("Orchestrator stopped")
|
logger.info("Orchestrator stopped")
|
||||||
|
|
||||||
|
|||||||
46
xtrm_agent/scrub.py
Normal file
46
xtrm_agent/scrub.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Credential scrubbing — prevent secret leakage in tool output."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Patterns that match common secret formats
|
||||||
|
_SECRET_PATTERNS = [
|
||||||
|
# Key=value patterns (API keys, tokens, passwords)
|
||||||
|
re.compile(
|
||||||
|
r"(?i)(api[_-]?key|token|password|passwd|secret|access[_-]?key|private[_-]?key|auth)"
|
||||||
|
r"[\s]*[=:]\s*['\"]?([^\s'\"]{8,})['\"]?",
|
||||||
|
),
|
||||||
|
# Bearer tokens
|
||||||
|
re.compile(r"Bearer\s+[A-Za-z0-9\-._~+/]+=*", re.IGNORECASE),
|
||||||
|
# AWS access keys (AKIA...)
|
||||||
|
re.compile(r"AKIA[0-9A-Z]{16}"),
|
||||||
|
# AWS secret keys (40 char base64)
|
||||||
|
re.compile(r"(?i)aws[_-]?secret[_-]?access[_-]?key[\s]*[=:]\s*['\"]?([A-Za-z0-9/+=]{40})['\"]?"),
|
||||||
|
# GitHub tokens
|
||||||
|
re.compile(r"gh[pousr]_[A-Za-z0-9_]{36,}"),
|
||||||
|
# Generic long hex strings that look like secrets (32+ hex chars after key= or token=)
|
||||||
|
re.compile(r"(?i)(?:key|token|secret)[=:]\s*['\"]?([0-9a-f]{32,})['\"]?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
_REDACTED = "[REDACTED]"
|
||||||
|
|
||||||
|
|
||||||
|
def scrub_credentials(text: str) -> str:
|
||||||
|
"""Scrub potential secrets from text, replacing with [REDACTED]."""
|
||||||
|
result = text
|
||||||
|
for pattern in _SECRET_PATTERNS:
|
||||||
|
result = pattern.sub(_redact_match, result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _redact_match(match: re.Match) -> str:
|
||||||
|
"""Replace the secret value while keeping the key name visible."""
|
||||||
|
full = match.group(0)
|
||||||
|
# For key=value patterns, keep the key part
|
||||||
|
for sep in ("=", ":"):
|
||||||
|
if sep in full:
|
||||||
|
key_part = full[: full.index(sep) + 1]
|
||||||
|
return f"{key_part} {_REDACTED}"
|
||||||
|
# For standalone patterns (Bearer, AKIA, gh*_), redact the whole thing
|
||||||
|
return _REDACTED
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ class ApprovalPolicy(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class ApprovalEngine:
|
class ApprovalEngine:
|
||||||
"""Deny-by-default tool approval."""
|
"""Deny-by-default tool approval with session allowlist and audit log."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -26,6 +27,8 @@ class ApprovalEngine:
|
|||||||
self._auto_approve = set(auto_approve or [])
|
self._auto_approve = set(auto_approve or [])
|
||||||
self._require_approval = set(require_approval or [])
|
self._require_approval = set(require_approval or [])
|
||||||
self._interactive = interactive
|
self._interactive = interactive
|
||||||
|
self._session_allowed: set[str] = set()
|
||||||
|
self._audit_log: list[dict[str, Any]] = []
|
||||||
|
|
||||||
def get_policy(self, tool_name: str) -> ApprovalPolicy:
|
def get_policy(self, tool_name: str) -> ApprovalPolicy:
|
||||||
"""Get the approval policy for a tool."""
|
"""Get the approval policy for a tool."""
|
||||||
@@ -43,27 +46,54 @@ class ApprovalEngine:
|
|||||||
policy = self.get_policy(tool_name)
|
policy = self.get_policy(tool_name)
|
||||||
|
|
||||||
if policy == ApprovalPolicy.AUTO_APPROVE:
|
if policy == ApprovalPolicy.AUTO_APPROVE:
|
||||||
|
self._log_decision(tool_name, arguments, "auto_approved")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Session-scoped "always allow"
|
||||||
|
if tool_name in self._session_allowed:
|
||||||
|
self._log_decision(tool_name, arguments, "session_allowed")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if policy == ApprovalPolicy.DENY:
|
if policy == ApprovalPolicy.DENY:
|
||||||
logger.warning(f"Tool '{tool_name}' denied by policy")
|
logger.warning(f"Tool '{tool_name}' denied by policy")
|
||||||
|
self._log_decision(tool_name, arguments, "denied")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# REQUIRE_APPROVAL
|
# REQUIRE_APPROVAL
|
||||||
if not self._interactive:
|
if not self._interactive:
|
||||||
logger.warning(f"Tool '{tool_name}' requires approval but running non-interactively — denied")
|
logger.warning(f"Tool '{tool_name}' requires approval but running non-interactively — denied")
|
||||||
|
self._log_decision(tool_name, arguments, "denied_non_interactive")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# In interactive mode, prompt the user
|
# In interactive mode, prompt the user
|
||||||
logger.info(f"Tool '{tool_name}' requires approval. Args: {arguments}")
|
logger.info(f"Tool '{tool_name}' requires approval. Args: {arguments}")
|
||||||
return await self._prompt_user(tool_name, arguments)
|
approved, always = await self._prompt_user(tool_name, arguments)
|
||||||
|
if approved and always:
|
||||||
|
self._session_allowed.add(tool_name)
|
||||||
|
self._log_decision(tool_name, arguments, "user_approved" if approved else "user_denied")
|
||||||
|
return approved
|
||||||
|
|
||||||
async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> bool:
|
async def _prompt_user(self, tool_name: str, arguments: dict[str, Any]) -> tuple[bool, bool]:
|
||||||
"""Prompt user for tool approval (interactive mode)."""
|
"""Prompt user for tool approval. Returns (approved, always_allow)."""
|
||||||
print(f"\n[APPROVAL REQUIRED] Tool: {tool_name}")
|
print(f"\n[APPROVAL REQUIRED] Tool: {tool_name}")
|
||||||
print(f" Arguments: {arguments}")
|
print(f" Arguments: {arguments}")
|
||||||
try:
|
try:
|
||||||
answer = input(" Allow? [y/N]: ").strip().lower()
|
answer = input(" Allow? [y/N/a(lways)]: ").strip().lower()
|
||||||
return answer in ("y", "yes")
|
if answer in ("a", "always"):
|
||||||
|
return True, True
|
||||||
|
return answer in ("y", "yes"), False
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
return False
|
return False, False
|
||||||
|
|
||||||
|
def _log_decision(self, tool_name: str, arguments: dict[str, Any], decision: str) -> None:
|
||||||
|
"""Record an approval decision in the audit log."""
|
||||||
|
self._audit_log.append({
|
||||||
|
"tool": tool_name,
|
||||||
|
"arguments": arguments,
|
||||||
|
"decision": decision,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_audit_log(self) -> list[dict[str, Any]]:
|
||||||
|
"""Return the audit log for inspection."""
|
||||||
|
return list(self._audit_log)
|
||||||
|
|||||||
@@ -216,6 +216,26 @@ class BashTool(Tool):
|
|||||||
return f"Error: Command timed out after {self._timeout}s"
|
return f"Error: Command timed out after {self._timeout}s"
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(html: str) -> str:
|
||||||
|
"""Strip HTML tags and collapse whitespace to get readable text."""
|
||||||
|
# Remove script and style blocks
|
||||||
|
text = re.sub(r"<(script|style)[^>]*>.*?</\1>", "", html, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
# Replace <br>, <p>, <div>, <li> etc. with newlines
|
||||||
|
text = re.sub(r"<(br|p|div|li|h[1-6]|tr)[^>]*/?>", "\n", text, flags=re.IGNORECASE)
|
||||||
|
# Strip remaining tags
|
||||||
|
text = re.sub(r"<[^>]+>", "", text)
|
||||||
|
# Decode common HTML entities
|
||||||
|
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
|
text = text.replace(""", '"').replace("'", "'").replace(" ", " ")
|
||||||
|
# Collapse whitespace
|
||||||
|
text = re.sub(r"[ \t]+", " ", text)
|
||||||
|
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
_WEB_USER_AGENT = "Mozilla/5.0 (compatible; XtrmAgent/1.0; +https://github.com)"
|
||||||
|
|
||||||
|
|
||||||
class WebFetchTool(Tool):
|
class WebFetchTool(Tool):
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -223,7 +243,7 @@ class WebFetchTool(Tool):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Fetch the content of a URL."
|
return "Fetch the content of a URL and return it as readable text."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
@@ -237,9 +257,14 @@ class WebFetchTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, url: str, **_: Any) -> str:
|
async def execute(self, url: str, **_: Any) -> str:
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
async with httpx.AsyncClient(
|
||||||
|
timeout=30, follow_redirects=True, headers={"User-Agent": _WEB_USER_AGENT}
|
||||||
|
) as client:
|
||||||
resp = await client.get(url)
|
resp = await client.get(url)
|
||||||
|
content_type = resp.headers.get("content-type", "")
|
||||||
text = resp.text
|
text = resp.text
|
||||||
|
if "html" in content_type:
|
||||||
|
text = _strip_html(text)
|
||||||
if len(text) > 20_000:
|
if len(text) > 20_000:
|
||||||
text = text[:20_000] + "\n... (truncated)"
|
text = text[:20_000] + "\n... (truncated)"
|
||||||
return text
|
return text
|
||||||
@@ -247,6 +272,51 @@ class WebFetchTool(Tool):
|
|||||||
return f"Error fetching URL: {e}"
|
return f"Error fetching URL: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchTool(Tool):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "web_search"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Search the web using DuckDuckGo and return a list of results with title, URL, and snippet."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "Search query"},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of results (default: 5)",
|
||||||
|
"default": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, query: str, max_results: int = 5, **_: Any) -> str:
|
||||||
|
try:
|
||||||
|
from duckduckgo_search import AsyncDDGS
|
||||||
|
|
||||||
|
async with AsyncDDGS() as ddgs:
|
||||||
|
results = await ddgs.atext(query, max_results=max_results)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return "No results found."
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
for r in results:
|
||||||
|
lines.append(f"**{r.get('title', '')}**")
|
||||||
|
lines.append(r.get("href", ""))
|
||||||
|
lines.append(r.get("body", ""))
|
||||||
|
lines.append("---")
|
||||||
|
return "\n".join(lines)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error searching: {e}"
|
||||||
|
|
||||||
|
|
||||||
def register_builtin_tools(registry: Any, workspace: Path) -> None:
|
def register_builtin_tools(registry: Any, workspace: Path) -> None:
|
||||||
"""Register all built-in tools into a ToolRegistry."""
|
"""Register all built-in tools into a ToolRegistry."""
|
||||||
registry.register(ReadFileTool(workspace))
|
registry.register(ReadFileTool(workspace))
|
||||||
@@ -255,3 +325,4 @@ def register_builtin_tools(registry: Any, workspace: Path) -> None:
|
|||||||
registry.register(ListDirTool(workspace))
|
registry.register(ListDirTool(workspace))
|
||||||
registry.register(BashTool(workspace))
|
registry.register(BashTool(workspace))
|
||||||
registry.register(WebFetchTool())
|
registry.register(WebFetchTool())
|
||||||
|
registry.register(WebSearchTool())
|
||||||
|
|||||||
Reference in New Issue
Block a user