"""Query classifier: route queries to appropriate models."""

from __future__ import annotations

import re
from dataclasses import dataclass


@dataclass
class ClassificationResult:
    """Result of classifying a query."""

    category: str  # "simple", "code", "research", or "complex"
    model_hint: str  # suggested model override, or "" to use the default model


# Short, trivial queries: greetings, thanks, yes/no, etc.
_SIMPLE_PATTERNS = [
    re.compile(
        r"^(hi|hello|hey|thanks|thank you|ok|yes|no|sure|bye|good\s*(morning|evening|night))[\s!.?]*$",
        re.IGNORECASE,
    ),
    re.compile(r"^(what time|what date|what day)[\s\S]{0,20}\??\s*$", re.IGNORECASE),
]

# Code-related queries. The language-keyword pattern (first entry) is
# case-sensitive; the others ignore case or contain no letters.
_CODE_PATTERNS = [
    re.compile(r"\b(function|class|def|import|const|let|var|return|async|await)\b"),
    re.compile(r"\b(refactor|debug|fix|implement|code|compile|build|test|lint|deploy)\b", re.IGNORECASE),
    re.compile(r"```"),
    re.compile(r"\.(py|js|ts|rs|go|java|cpp|c|rb|php|swift|kt)\b"),
]

# Research / information-seeking queries
_RESEARCH_PATTERNS = [
    re.compile(r"\b(search|find|look\s*up|research|summarize|explain|compare|analyze)\b", re.IGNORECASE),
    re.compile(r"\b(what is|how does|why does|can you explain)\b", re.IGNORECASE),
]


class QueryClassifier:
    """Classify queries for model routing.

    Configure with a mapping of category → model override.
    Categories: "simple", "code", "research", "complex".
    A category missing from the mapping yields an empty ``model_hint``,
    meaning the caller's default model should be used.
    """

    def __init__(self, model_map: dict[str, str] | None = None) -> None:
        self._model_map = model_map or {}

    def _result(self, category: str) -> ClassificationResult:
        """Build a result, looking up the configured override for *category*."""
        return ClassificationResult(
            category=category,
            model_hint=self._model_map.get(category, ""),
        )

    def classify(self, query: str) -> ClassificationResult:
        """Classify a query and suggest a model.

        Checks run in priority order (simple, code, research, complex);
        the first match wins.
        """
        query_stripped = query.strip()

        # Only queries shorter than 20 characters are tested against the
        # simple patterns, so longer "what time ..." questions skip this check
        # even though the second pattern would accept them.
        if len(query_stripped) < 20:
            if any(p.match(query_stripped) for p in _SIMPLE_PATTERNS):
                return self._result("simple")

        # Code: at least two distinct code signals are required, so a single
        # incidental word such as "test" or "build" is not enough on its own.
        code_score = sum(1 for p in _CODE_PATTERNS if p.search(query_stripped))
        if code_score >= 2:
            return self._result("code")

        # Research: one matching keyword or phrase is enough.
        research_score = sum(1 for p in _RESEARCH_PATTERNS if p.search(query_stripped))
        if research_score >= 1:
            return self._result("research")

        # Long (over 500 characters) or multi-line (more than 5 newlines) queries
        if len(query_stripped) > 500 or query_stripped.count("\n") > 5:
            return self._result("complex")

        # Fallback: labelled "complex" but with no model override, so the
        # caller's default model handles it.
        return ClassificationResult(category="complex", model_hint="")