Inspired by zeroclaw's lightweight patterns for slow hardware: - Response cache (SQLite + SHA-256 keyed) to skip redundant LLM calls - History compaction — LLM-summarize old messages when history exceeds 50 - Query classifier routes simple/research queries to cheaper models - Credential scrubbing removes secrets from tool output before sending to LLM - Cost tracker with daily/monthly budget enforcement (SQLite) - Resilient provider with retry + exponential backoff + fallback provider - Approval engine gains session "always allow" and audit log Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
293 lines
11 KiB
Python
293 lines
11 KiB
Python
"""Single agent engine — one LLM loop per agent."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from xtrm_agent.cache import ResponseCache
|
|
from xtrm_agent.classifier import ClassificationResult, QueryClassifier
|
|
from xtrm_agent.config import AgentFileConfig
|
|
from xtrm_agent.cost import CostTracker
|
|
from xtrm_agent.llm.provider import LLMProvider, LLMResponse
|
|
from xtrm_agent.scrub import scrub_credentials
|
|
from xtrm_agent.tools.approval import ApprovalEngine
|
|
from xtrm_agent.tools.registry import ToolRegistry
|
|
|
|
# History compaction settings
_MAX_HISTORY_MESSAGES = 50  # compact once non-system history grows past this count
_KEEP_RECENT = 20  # newest non-system messages kept verbatim when compacting
# System prompt used when asking the LLM to summarize older history.
_COMPACTION_PROMPT = (
    "Summarize the following conversation history concisely. "
    "Preserve key facts, decisions, tool results, and context needed to continue. "
    "Be brief but complete."
)

# Retry settings
_MAX_RETRIES = 3  # attempts against the primary provider before fallback
_RETRY_BACKOFF_BASE = 2.0  # backoff wait is base**attempt seconds (1s, 2s, 4s, ...)
# Substrings (matched against the lowercased error text) that indicate a
# retry cannot succeed — e.g. bad credentials or a nonexistent model.
_NON_RETRYABLE_KEYWORDS = [
    "authentication",
    "unauthorized",
    "invalid api key",
    "model not found",
    "permission denied",
    "invalid_api_key",
    "401",
    "403",
    "404",
]
|
|
|
|
|
|
class Engine:
    """Runs one agent's LLM loop: messages → LLM → tool calls → loop → response.

    Optional collaborators (any may be None, in which case that feature is
    skipped):
      cache             — response cache consulted/filled for tool-free calls
      cost_tracker      — records token usage and enforces a spend budget
      classifier        — may route a query to a cheaper model via a hint
      fallback_provider — tried once after the primary provider is exhausted
    """

    def __init__(
        self,
        agent_config: AgentFileConfig,
        provider: LLMProvider,
        tools: ToolRegistry,
        approval: ApprovalEngine,
        cache: ResponseCache | None = None,
        cost_tracker: CostTracker | None = None,
        classifier: QueryClassifier | None = None,
        fallback_provider: LLMProvider | None = None,
    ) -> None:
        self.config = agent_config
        self.provider = provider
        self.tools = tools
        self.approval = approval
        self.cache = cache
        self.cost_tracker = cost_tracker
        self.classifier = classifier
        self.fallback_provider = fallback_provider

    async def run(self, user_message: str) -> str:
        """Process a single user message through the agent loop."""
        messages = self._build_initial_messages(user_message)

        # Query classification — override model if classifier suggests one
        model_override = ""
        if self.classifier:
            result = self.classifier.classify(user_message)
            if result.model_hint:
                model_override = result.model_hint
                logger.debug(f"[{self.config.name}] Classified as '{result.category}' → model={result.model_hint}")

        return await self._agent_loop(messages, model_override=model_override)

    async def run_delegation(self, task: str) -> str:
        """Process a delegation task (no classification, no model override)."""
        messages = self._build_initial_messages(task)
        return await self._agent_loop(messages)

    def _build_initial_messages(self, user_message: str) -> list[dict[str, Any]]:
        """Return the opening message list: optional system prompt + user turn."""
        messages: list[dict[str, Any]] = []
        if self.config.instructions:
            messages.append({"role": "system", "content": self.config.instructions})
        messages.append({"role": "user", "content": user_message})
        return messages

    async def _agent_loop(
        self,
        messages: list[dict[str, Any]],
        model_override: str = "",
    ) -> str:
        """Core agent iteration loop with caching, cost tracking, retry, and compaction.

        Each iteration: budget gate → (first pass, tool-free only) cache
        lookup → LLM call with retry/fallback → cost recording → either
        return the final text or execute the requested tools, append their
        scrubbed results, compact history if needed, and loop again.
        """
        for iteration in range(self.config.max_iterations):
            model = model_override or self.config.model or self.provider.get_default_model()
            tool_defs = self.tools.get_definitions() if self.tools.names() else None

            # Budget check — refuse further spend once the tracker says so.
            if self.cost_tracker:
                status = await self.cost_tracker.check_budget()
                if status == "exceeded":
                    return "(budget exceeded — please try again later)"

            # Response cache check (only for first iteration — no tool calls yet)
            if self.cache and iteration == 0 and not tool_defs:
                cached = await self.cache.get(model, messages)
                if cached:
                    return cached

            # LLM call with retry + fallback
            response = await self._call_with_retry(
                messages=messages,
                tools=tool_defs,
                model=model,
            )

            # Track cost
            if self.cost_tracker and response.usage:
                await self.cost_tracker.record(model, response.usage)

            if not response.has_tool_calls:
                content = response.content or "(no response)"
                # Cache only what get() above can later serve: first-pass,
                # tool-free responses. (get() requires `not tool_defs`, so an
                # entry written while tools were registered would be dead.)
                if self.cache and iteration == 0 and not tool_defs:
                    await self.cache.put(model, messages, content)
                return content

            # Add assistant message with tool calls
            messages.append(self._assistant_message(response))

            # Execute each tool call and append its result as a tool message.
            for tc in response.tool_calls:
                result = await self._execute_tool(tc.name, tc.arguments)
                # Scrub credentials from tool output before it reaches the LLM
                result = scrub_credentials(result)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "name": tc.name,
                        "content": result,
                    }
                )

            logger.debug(
                f"[{self.config.name}] Iteration {iteration + 1}: "
                f"{len(response.tool_calls)} tool call(s)"
            )

            # History compaction — summarize old messages if history is too long
            messages = await self._compact_history(messages, model)

        return "(max iterations reached)"

    async def _call_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str,
    ) -> LLMResponse:
        """Call the LLM with exponential-backoff retry, then a fallback provider.

        Non-retryable errors (auth/permission/not-found) abort the retry loop
        immediately. On total failure returns an error LLMResponse rather
        than raising, so the agent loop can surface a message to the user.
        """
        last_error: Exception | None = None

        for attempt in range(_MAX_RETRIES):
            try:
                return await self.provider.complete(
                    messages=messages,
                    tools=tools,
                    model=model,
                    max_tokens=8192,
                    temperature=self.config.temperature,
                )
            except Exception as e:
                last_error = e
                error_str = str(e).lower()

                # Don't retry non-retryable errors
                if any(kw in error_str for kw in _NON_RETRYABLE_KEYWORDS):
                    logger.error(f"[{self.config.name}] Non-retryable error: {e}")
                    break

                # No point sleeping after the final attempt — go straight
                # to the fallback provider instead.
                if attempt == _MAX_RETRIES - 1:
                    break

                wait = _RETRY_BACKOFF_BASE ** attempt
                logger.warning(
                    f"[{self.config.name}] Attempt {attempt + 1}/{_MAX_RETRIES} failed: {e} — retrying in {wait}s"
                )
                await asyncio.sleep(wait)

        # Try fallback provider
        if self.fallback_provider:
            logger.info(f"[{self.config.name}] Trying fallback provider")
            try:
                return await self.fallback_provider.complete(
                    messages=messages,
                    tools=tools,
                    model=self.fallback_provider.get_default_model(),
                    max_tokens=8192,
                    temperature=self.config.temperature,
                )
            except Exception as fallback_err:
                logger.error(f"[{self.config.name}] Fallback also failed: {fallback_err}")

        # All retries and fallback exhausted
        return LLMResponse(
            content=f"(error after {_MAX_RETRIES} retries: {last_error})",
            finish_reason="error",
        )

    async def _compact_history(
        self,
        messages: list[dict[str, Any]],
        model: str,
    ) -> list[dict[str, Any]]:
        """Compact old messages if history exceeds threshold.

        System messages and the most recent turns are kept verbatim; the
        remainder is LLM-summarized into a synthetic user/assistant pair.
        Returns the original list unchanged when under threshold or when
        summarization fails.
        """
        # Count non-system messages
        non_system = [m for m in messages if m["role"] != "system"]
        if len(non_system) <= _MAX_HISTORY_MESSAGES:
            return messages

        # Split: keep system messages, summarize old, keep recent.
        # Never let the recent window start on a tool result whose parent
        # assistant tool-call message was summarized away — chat APIs
        # reject a tool message with no preceding tool call. Walk the
        # split point back until the first kept message isn't a tool result.
        system_msgs = [m for m in messages if m["role"] == "system"]
        split = len(non_system) - _KEEP_RECENT
        while split > 0 and non_system[split]["role"] == "tool":
            split -= 1
        old_msgs = non_system[:split]
        recent_msgs = non_system[split:]
        if not old_msgs:
            # Nothing left to summarize after boundary adjustment.
            return messages

        # Ask LLM to summarize the old messages
        summary_messages = [
            {"role": "system", "content": _COMPACTION_PROMPT},
            {
                "role": "user",
                "content": json.dumps(
                    [{"role": m["role"], "content": m.get("content", "")} for m in old_msgs],
                    indent=None,
                ),
            },
        ]

        try:
            summary_response = await self.provider.complete(
                messages=summary_messages,
                model=model,
                max_tokens=2048,
                temperature=0.1,
            )
            summary_text = summary_response.content or ""
            logger.info(
                f"[{self.config.name}] Compacted {len(old_msgs)} messages → "
                f"{len(summary_text)} char summary"
            )
        except Exception as e:
            # Best-effort: keep the full history rather than lose context.
            logger.warning(f"[{self.config.name}] History compaction failed: {e}")
            return messages

        # Rebuild: system + summary + recent
        compacted = list(system_msgs)
        compacted.append({
            "role": "user",
            "content": f"[Conversation summary]\n{summary_text}",
        })
        compacted.append({
            "role": "assistant",
            "content": "Understood, I have the context from our previous conversation.",
        })
        compacted.extend(recent_msgs)
        return compacted

    async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> str:
        """Execute a tool after an approval-policy check; return its output."""
        approved = await self.approval.check(name, arguments)
        if not approved:
            return f"Tool '{name}' was denied by approval policy."
        return await self.tools.execute(name, arguments)

    def _assistant_message(self, response: LLMResponse) -> dict[str, Any]:
        """Build an OpenAI-style assistant message dict from an LLMResponse."""
        msg: dict[str, Any] = {"role": "assistant"}
        if response.content:
            msg["content"] = response.content
        if response.tool_calls:
            msg["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)},
                }
                for tc in response.tool_calls
            ]
        return msg
|