Add performance features: caching, cost tracking, retry, compaction, classification, scrubbing
Inspired by zeroclaw's lightweight patterns for slow hardware:

- Response cache (SQLite + SHA-256 keyed) to skip redundant LLM calls
- History compaction — LLM-summarize old messages when history exceeds 50
- Query classifier routes simple/research queries to cheaper models
- Credential scrubbing removes secrets from tool output before sending to LLM
- Cost tracker with daily/monthly budget enforcement (SQLite)
- Resilient provider with retry + exponential backoff + fallback provider
- Approval engine gains session "always allow" and audit log

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
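As a rough wiring sketch of how the pieces above compose (the helper-module constructor arguments and the surrounding variable names are hypothetical; only the Engine keyword parameters appear in the diff below):

# Hypothetical wiring — paths and constructor signatures are illustrative only.
engine = Engine(
    agent_config=agent_config,
    provider=primary_provider,
    tools=tool_registry,
    approval=approval_engine,
    cache=ResponseCache("cache.db"),
    cost_tracker=CostTracker("costs.db"),
    classifier=QueryClassifier(),
    fallback_provider=backup_provider,
)
reply = await engine.run("Summarize today's logs")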
@@ -2,17 +2,45 @@
 from __future__ import annotations
 
+import asyncio
 import json
 from typing import Any
 
 from loguru import logger
 
+from xtrm_agent.cache import ResponseCache
+from xtrm_agent.classifier import ClassificationResult, QueryClassifier
 from xtrm_agent.config import AgentFileConfig
+from xtrm_agent.cost import CostTracker
 from xtrm_agent.llm.provider import LLMProvider, LLMResponse
+from xtrm_agent.scrub import scrub_credentials
 from xtrm_agent.tools.approval import ApprovalEngine
 from xtrm_agent.tools.registry import ToolRegistry
 
+# History compaction settings
+_MAX_HISTORY_MESSAGES = 50
+_KEEP_RECENT = 20
+_COMPACTION_PROMPT = (
+    "Summarize the following conversation history concisely. "
+    "Preserve key facts, decisions, tool results, and context needed to continue. "
+    "Be brief but complete."
+)
+
+# Retry settings
+_MAX_RETRIES = 3
+_RETRY_BACKOFF_BASE = 2.0
+_NON_RETRYABLE_KEYWORDS = [
+    "authentication",
+    "unauthorized",
+    "invalid api key",
+    "model not found",
+    "permission denied",
+    "invalid_api_key",
+    "401",
+    "403",
+    "404",
+]
 
 
 class Engine:
     """Runs one agent's LLM loop: messages → LLM → tool calls → loop → response."""
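The diff only shows the cache's call sites, cache.get(model, messages) and cache.put(model, messages, content). A minimal sketch of what xtrm_agent.cache's ResponseCache could look like, following the commit message's "SQLite + SHA-256 keyed" description; the table schema, constructor, and the synchronous sqlite3 calls inside async methods are assumptions:

import hashlib
import json
import sqlite3
from typing import Any


class ResponseCache:
    """SQLite-backed response cache keyed by SHA-256 of (model, messages)."""

    def __init__(self, path: str = "response_cache.db") -> None:
        self._db = sqlite3.connect(path)
        self._db.execute(
            "CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, response TEXT)"
        )

    @staticmethod
    def _key(model: str, messages: list[dict[str, Any]]) -> str:
        # Canonical serialization so identical conversations produce identical keys.
        payload = json.dumps({"model": model, "messages": messages}, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    async def get(self, model: str, messages: list[dict[str, Any]]) -> str | None:
        row = self._db.execute(
            "SELECT response FROM cache WHERE key = ?",
            (self._key(model, messages),),
        ).fetchone()
        return row[0] if row else None

    async def put(self, model: str, messages: list[dict[str, Any]], response: str) -> None:
        self._db.execute(
            "INSERT OR REPLACE INTO cache (key, response) VALUES (?, ?)",
            (self._key(model, messages), response),
        )
        self._db.commit()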
@@ -23,16 +51,33 @@ class Engine:
         provider: LLMProvider,
         tools: ToolRegistry,
         approval: ApprovalEngine,
+        cache: ResponseCache | None = None,
+        cost_tracker: CostTracker | None = None,
+        classifier: QueryClassifier | None = None,
+        fallback_provider: LLMProvider | None = None,
     ) -> None:
         self.config = agent_config
         self.provider = provider
         self.tools = tools
         self.approval = approval
+        self.cache = cache
+        self.cost_tracker = cost_tracker
+        self.classifier = classifier
+        self.fallback_provider = fallback_provider
 
     async def run(self, user_message: str) -> str:
         """Process a single user message through the agent loop."""
         messages = self._build_initial_messages(user_message)
-        return await self._agent_loop(messages)
+
+        # Query classification — override model if classifier suggests one
+        model_override = ""
+        if self.classifier:
+            result = self.classifier.classify(user_message)
+            if result.model_hint:
+                model_override = result.model_hint
+                logger.debug(f"[{self.config.name}] Classified as '{result.category}' → model={result.model_hint}")
+
+        return await self._agent_loop(messages, model_override=model_override)
 
     async def run_delegation(self, task: str) -> str:
         """Process a delegation task (no system prompt changes)."""
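classify() is only used for its .category and .model_hint attributes here. One plausible shape for xtrm_agent.classifier, with made-up heuristics and placeholder model names (the real routing rules are not shown in this diff):

from dataclasses import dataclass


@dataclass
class ClassificationResult:
    category: str
    model_hint: str = ""


class QueryClassifier:
    """Route simple/research queries to cheaper models via keyword heuristics."""

    # Hypothetical model names — substitute whatever the provider config offers.
    SIMPLE_MODEL = "small-model"
    RESEARCH_MODEL = "mid-model"

    def classify(self, query: str) -> ClassificationResult:
        q = query.lower()
        if len(q) < 80 and not any(w in q for w in ("code", "debug", "implement")):
            return ClassificationResult("simple", self.SIMPLE_MODEL)
        if any(w in q for w in ("search", "look up", "research", "find out")):
            return ClassificationResult("research", self.RESEARCH_MODEL)
        return ClassificationResult("complex")  # empty hint keeps the configured model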
@@ -46,22 +91,45 @@ class Engine:
         messages.append({"role": "user", "content": user_message})
         return messages
 
-    async def _agent_loop(self, messages: list[dict[str, Any]]) -> str:
-        """Core agent iteration loop."""
+    async def _agent_loop(
+        self,
+        messages: list[dict[str, Any]],
+        model_override: str = "",
+    ) -> str:
+        """Core agent iteration loop with caching, cost tracking, retry, and compaction."""
         for iteration in range(self.config.max_iterations):
-            model = self.config.model or self.provider.get_default_model()
+            model = model_override or self.config.model or self.provider.get_default_model()
             tool_defs = self.tools.get_definitions() if self.tools.names() else None
 
-            response = await self.provider.complete(
+            # Budget check
+            if self.cost_tracker:
+                status = await self.cost_tracker.check_budget()
+                if status == "exceeded":
+                    return "(budget exceeded — please try again later)"
+
+            # Response cache check (only for first iteration — no tool calls yet)
+            if self.cache and iteration == 0 and not tool_defs:
+                cached = await self.cache.get(model, messages)
+                if cached:
+                    return cached
+
+            # LLM call with retry + fallback
+            response = await self._call_with_retry(
                 messages=messages,
                 tools=tool_defs,
                 model=model,
-                max_tokens=8192,
-                temperature=self.config.temperature,
             )
 
+            # Track cost
+            if self.cost_tracker and response.usage:
+                await self.cost_tracker.record(model, response.usage)
+
             if not response.has_tool_calls:
-                return response.content or "(no response)"
+                content = response.content or "(no response)"
+                # Cache the final response (only if no tool calls were used in the conversation)
+                if self.cache and iteration == 0:
+                    await self.cache.put(model, messages, content)
+                return content
 
             # Add assistant message with tool calls
             messages.append(self._assistant_message(response))
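check_budget() returning the string "exceeded" and record(model, usage) are the only CostTracker contracts visible in the diff. A minimal SQLite sketch under those constraints; the pricing, the usage object's fields, and the budget defaults are all assumptions:

import datetime as dt
import sqlite3


class CostTracker:
    """SQLite spend ledger with daily/monthly budget enforcement."""

    def __init__(self, path: str, daily_usd: float = 5.0, monthly_usd: float = 50.0) -> None:
        self._db = sqlite3.connect(path)
        self._db.execute("CREATE TABLE IF NOT EXISTS spend (day TEXT, model TEXT, usd REAL)")
        self.daily_usd = daily_usd
        self.monthly_usd = monthly_usd

    async def record(self, model: str, usage) -> None:
        # `usage` is provider-specific; assume token counts and a flat $3/M rate here.
        usd = (usage.input_tokens + usage.output_tokens) / 1_000_000 * 3.0
        self._db.execute(
            "INSERT INTO spend VALUES (?, ?, ?)",
            (dt.date.today().isoformat(), model, usd),
        )
        self._db.commit()

    async def check_budget(self) -> str:
        today = dt.date.today().isoformat()
        day = self._db.execute(
            "SELECT COALESCE(SUM(usd), 0) FROM spend WHERE day = ?", (today,)
        ).fetchone()[0]
        month = self._db.execute(
            "SELECT COALESCE(SUM(usd), 0) FROM spend WHERE day LIKE ?",
            (today[:7] + "%",),
        ).fetchone()[0]
        return "exceeded" if day > self.daily_usd or month > self.monthly_usd else "ok"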
@@ -69,6 +137,8 @@ class Engine:
             # Execute each tool call
             for tc in response.tool_calls:
                 result = await self._execute_tool(tc.name, tc.arguments)
+                # Scrub credentials from tool output
+                result = scrub_credentials(result)
                 messages.append(
                     {
                         "role": "tool",
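scrub_credentials() runs on every tool result before it enters the history. A sketch of a regex-based implementation; the real pattern list in xtrm_agent.scrub is not shown in this diff:

import re

# Hypothetical patterns — extend with whatever secrets the tools can leak.
_SECRET_PATTERNS = [
    re.compile(r"sk-[A-Za-z0-9]{20,}"),                       # API-key-like tokens
    re.compile(r"(?i)(password|token|secret)\s*[=:]\s*\S+"),  # key=value credentials
    re.compile(r"AKIA[0-9A-Z]{16}"),                          # AWS access key IDs
]


def scrub_credentials(text: str) -> str:
    """Replace anything credential-shaped with a placeholder before it reaches the LLM."""
    for pattern in _SECRET_PATTERNS:
        text = pattern.sub("[REDACTED]", text)
    return text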
@@ -83,8 +153,121 @@ class Engine:
                 f"{len(response.tool_calls)} tool call(s)"
             )
 
+            # History compaction — summarize old messages if history is too long
+            messages = await self._compact_history(messages, model)
+
         return "(max iterations reached)"
 
+    async def _call_with_retry(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None,
+        model: str,
+    ) -> LLMResponse:
+        """Call the LLM with retry and fallback."""
+        last_error: Exception | None = None
+
+        for attempt in range(_MAX_RETRIES):
+            try:
+                return await self.provider.complete(
+                    messages=messages,
+                    tools=tools,
+                    model=model,
+                    max_tokens=8192,
+                    temperature=self.config.temperature,
+                )
+            except Exception as e:
+                last_error = e
+                error_str = str(e).lower()
+
+                # Don't retry non-retryable errors
+                if any(kw in error_str for kw in _NON_RETRYABLE_KEYWORDS):
+                    logger.error(f"[{self.config.name}] Non-retryable error: {e}")
+                    break
+
+                wait = _RETRY_BACKOFF_BASE ** attempt
+                logger.warning(
+                    f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — retrying in {wait}s"
+                )
+                await asyncio.sleep(wait)
+
+        # Try fallback provider
+        if self.fallback_provider:
+            logger.info(f"[{self.config.name}] Trying fallback provider")
+            try:
+                return await self.fallback_provider.complete(
+                    messages=messages,
+                    tools=tools,
+                    model=self.fallback_provider.get_default_model(),
+                    max_tokens=8192,
+                    temperature=self.config.temperature,
+                )
+            except Exception as fallback_err:
+                logger.error(f"[{self.config.name}] Fallback also failed: {fallback_err}")
+
+        # All retries and fallback exhausted
+        return LLMResponse(
+            content=f"(error after {_MAX_RETRIES} retries: {last_error})",
+            finish_reason="error",
+        )
+
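A note on the schedule: with _RETRY_BACKOFF_BASE = 2.0 the waits after failed attempts are 2**0 = 1 s, 2**1 = 2 s, and 2**2 = 4 s. Non-retryable errors (the keyword list above) break out immediately, so an authentication failure goes straight to the fallback provider instead of burning retries.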
+    async def _compact_history(
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+    ) -> list[dict[str, Any]]:
+        """Compact old messages if history exceeds threshold."""
+        # Count non-system messages
+        non_system = [m for m in messages if m["role"] != "system"]
+        if len(non_system) <= _MAX_HISTORY_MESSAGES:
+            return messages
+
+        # Split: keep system messages, summarize old, keep recent
+        system_msgs = [m for m in messages if m["role"] == "system"]
+        old_msgs = non_system[:-_KEEP_RECENT]
+        recent_msgs = non_system[-_KEEP_RECENT:]
+
+        # Ask LLM to summarize the old messages
+        summary_messages = [
+            {"role": "system", "content": _COMPACTION_PROMPT},
+            {
+                "role": "user",
+                "content": json.dumps(
+                    [{"role": m["role"], "content": m.get("content", "")} for m in old_msgs],
+                    indent=None,
+                ),
+            },
+        ]
+
+        try:
+            summary_response = await self.provider.complete(
+                messages=summary_messages,
+                model=model,
+                max_tokens=2048,
+                temperature=0.1,
+            )
+            summary_text = summary_response.content or ""
+            logger.info(
+                f"[{self.config.name}] Compacted {len(old_msgs)} messages → "
+                f"{len(summary_text)} char summary"
+            )
+        except Exception as e:
+            logger.warning(f"[{self.config.name}] History compaction failed: {e}")
+            return messages
+
+        # Rebuild: system + summary + recent
+        compacted = list(system_msgs)
+        compacted.append({
+            "role": "user",
+            "content": f"[Conversation summary]\n{summary_text}",
+        })
+        compacted.append({
+            "role": "assistant",
+            "content": "Understood, I have the context from our previous conversation.",
+        })
+        compacted.extend(recent_msgs)
+        return compacted
+
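A worked example of the thresholds: with 55 non-system messages, _MAX_HISTORY_MESSAGES = 50 is exceeded, so the oldest 35 are summarized and the newest 20 (_KEEP_RECENT) survive verbatim. The rebuilt history is the system messages, one summary user turn, one acknowledgment assistant turn, and those 20 recent messages, roughly 22 entries plus system prompts.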
     async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
         """Execute a tool with approval check."""
         approved = await self.approval.check(name, arguments)
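The diff is truncated here, but the commit message also mentions the approval engine's new session "always allow" and audit log. A minimal sketch of that behavior; the prompting logic and log format are assumptions, and only approval.check(name, arguments) is confirmed by the call site above:

import json
import time
from typing import Any


class ApprovalEngine:
    """Gate tool execution; supports session-wide "always allow" and an audit log."""

    def __init__(self, audit_path: str = "approvals.log") -> None:
        self._always_allowed: set[str] = set()
        self._audit_path = audit_path

    def always_allow(self, tool_name: str) -> None:
        # Session-scoped: the allow-list is cleared when the process exits.
        self._always_allowed.add(tool_name)

    async def check(self, name: str, arguments: dict[str, Any]) -> bool:
        approved = name in self._always_allowed or await self._prompt_user(name, arguments)
        # Append every decision to a JSON-lines audit log.
        with open(self._audit_path, "a") as f:
            f.write(json.dumps({
                "ts": time.time(),
                "tool": name,
                "arguments": arguments,
                "approved": approved,
            }) + "\n")
        return approved

    async def _prompt_user(self, name: str, arguments: dict[str, Any]) -> bool:
        # Placeholder for an interactive prompt; deny by default in this sketch.
        return False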