def build_finalizer_agent() -> Agent[None, str]:
    """Construct the plain-text "finalizer" agent used to salvage an answer.

    The agent is created without any tools and with ``output_type=str``; its
    system prompt is :data:`FINALIZER_SYSTEM_PROMPT`. The model is whatever
    :func:`build_agent_model_with_fallback` returns, and the mapping returned
    by :func:`get_model_settings` is forwarded to the ``Agent`` constructor as
    keyword arguments.

    It exists for the primary agent's misbehavior path (issue #351): a model
    that already fetched tool data but could not produce the structured
    output is asked to answer in plain text instead. With no tools registered
    there is nothing for it to call, and a ``str`` output has no schema to
    validate against.

    Returns:
        A plain-text :class:`Agent` backed by the primary+fallback model.
    """
    finalizer: Agent[None, str] = Agent(
        model=build_agent_model_with_fallback(),
        output_type=str,
        system_prompt=FINALIZER_SYSTEM_PROMPT,
        **get_model_settings(),
    )
    return finalizer
+_FINALIZER_DROP_KEYS = frozenset( + { + "model_config", + "model_config_data", + "feature_config", + "runtime_info", + "agent_context", + "config_hash", + "artifact_hash", + "artifact_uri", + "artifact_size_bytes", + "error_message", + } +) + class SessionNotFoundError(ValueError): """Session not found in the database.""" @@ -266,16 +292,20 @@ async def chat( history_length=len(message_history), ) + # Always bound for the misbehavior handler, even if the run raises before + # capture_run_messages() populates it. + captured_messages: list[ModelMessage] = [] try: - with _sequential_tool_execution(): - result = await asyncio.wait_for( - agent.run( - message, - deps=deps, - message_history=message_history, - ), - timeout=self.settings.agent_timeout_seconds, - ) + with capture_run_messages() as captured_messages: + with _sequential_tool_execution(): + result = await asyncio.wait_for( + agent.run( + message, + deps=deps, + message_history=message_history, + ), + timeout=self.settings.agent_timeout_seconds, + ) except TimeoutError as e: raise TimeoutError( f"Agent response timed out after {self.settings.agent_timeout_seconds} seconds" @@ -307,6 +337,13 @@ async def chat( pending_approval=True, pending_action=salvaged, ) + # A weak local model often calls tools and obtains the data, then + # fails to wrap it in the structured output schema (#351). Salvage a + # plain-text answer from the tool data already captured this run. + answer = await self._salvage_plaintext_answer(message, captured_messages) + if answer is not None: + logger.info("agents.chat_finalizer_salvage", session_id=session_id) + return ChatResponse(session_id=session_id, message=answer) return ChatResponse( session_id=session_id, message=( @@ -483,8 +520,10 @@ async def stream_chat( default_model = self.settings.agent_default_model provider = default_model.split(":", 1)[0] if ":" in default_model else "" stream_supported = provider != "ollama" + # Always bound for the misbehavior handler (see chat()). 
# Methods of ``AgentService`` (app/features/agents/service.py); the class
# header lies outside this excerpt.

@staticmethod
def _extract_tool_payloads(captured: list[ModelMessage]) -> list[dict[str, Any]]:
    """Collect every tool return found in a captured message trace.

    Used by :meth:`_salvage_plaintext_answer` to recover the data a model
    fetched before it failed structured-output validation (#351).

    Args:
        captured: Messages collected via ``capture_run_messages``; may be
            empty when the run failed before any tool returned.

    Returns:
        One ``{"tool": <tool name>, "result": <tool content>}`` dict per
        ``ToolReturnPart``, in trace order. Empty when there are none.
    """
    # NOTE(review): if the captured trace also contains the prior-turn
    # ``message_history``, tool returns from earlier turns are picked up
    # too — confirm against pydantic-ai's capture_run_messages semantics.
    return [
        {"tool": part.tool_name, "result": part.content}
        for message in captured
        # Messages without a ``parts`` attribute contribute nothing.
        for part in getattr(message, "parts", [])
        if isinstance(part, ToolReturnPart)
    ]

@classmethod
def _compact_for_finalizer(cls, obj: object) -> object:
    """Recursively drop :data:`_FINALIZER_DROP_KEYS` from tool data (#351).

    Dicts are rebuilt without the dropped keys and lists are rebuilt
    element by element; any other value is returned as-is. The input is
    not mutated and no I/O is performed.

    Args:
        obj: A value extracted from a tool return (dict, list or leaf).

    Returns:
        A structure of the same shape with the dropped keys removed at
        every dict level.
    """
    if isinstance(obj, dict):
        return {
            key: cls._compact_for_finalizer(value)
            for key, value in obj.items()
            if key not in _FINALIZER_DROP_KEYS
        }
    if isinstance(obj, list):
        return [cls._compact_for_finalizer(item) for item in obj]
    return obj

async def _salvage_plaintext_answer(
    self,
    message: str,
    captured: list[ModelMessage],
) -> str | None:
    """Answer in plain text from tool data when structured output failed (#351).

    Tool returns are pulled from ``captured``, compacted, serialised to
    JSON and handed — together with the user's question — to the tool-less,
    ``str``-output agent from ``build_finalizer_agent``.

    The JSON is capped at :data:`_FINALIZER_MAX_CHARS`. A character cut
    drops trailing entries and may leave the last one incomplete, so when
    the cap is hit the finalizer is told explicitly that the data is
    partial; otherwise it would present a minimum/maximum over the visible
    prefix as if it covered everything.

    Args:
        message: The original user message.
        captured: Messages captured from the failed run.

    Returns:
        The finalizer's stripped plain-text answer, or ``None`` when no tool
        data was captured, the answer is empty, or the finalizer raised or
        timed out (the caller then falls back to its generic error).
    """
    payloads = self._extract_tool_payloads(captured)
    if not payloads:
        return None
    try:
        # Imported lazily, inside the best-effort guard, so an import
        # problem degrades to ``None`` like any other finalizer failure.
        from app.features.agents.agents.base import build_finalizer_agent

        compact = self._compact_for_finalizer(payloads)
        # Compact separators carry the same JSON in fewer characters, so
        # more entries fit under the cap. ``default=str`` stringifies any
        # value ``json`` cannot encode natively.
        serialised = json.dumps(compact, default=str, separators=(",", ":"))
        truncated = len(serialised) > _FINALIZER_MAX_CHARS
        data = serialised[:_FINALIZER_MAX_CHARS] if truncated else serialised
        if truncated:
            logger.warning(
                "agents.finalizer_payload_truncated",
                original_chars=len(serialised),
                max_chars=_FINALIZER_MAX_CHARS,
            )
        prompt = (
            f"User question:\n{message}\n\n"
            f"Data retrieved from tools (JSON):\n{data}\n\n"
            "Answer the user's question concisely from this data. If the "
            "question asks for the lowest/highest of a metric (e.g. WAPE), "
            "compare that metric across ALL entries that have it, ignore "
            "entries where it is missing/null, and report the true "
            "minimum/maximum with its value."
        )
        if truncated:
            prompt += (
                "\n\nNOTE: the JSON above was cut off at "
                f"{_FINALIZER_MAX_CHARS} characters, so it is incomplete and "
                "its last entry may be partial. Use only entries that are "
                "fully visible, and say that the answer covers only part of "
                "the data."
            )
        finalizer = build_finalizer_agent()
        result = await asyncio.wait_for(
            finalizer.run(prompt),
            timeout=self.settings.agent_timeout_seconds,
        )
        text = str(result.output).strip()
        return text or None
    except Exception:
        # Best-effort: a finalizer failure must never replace the original
        # recoverable error with a crash.
        logger.warning("agents.finalizer_fallback_failed", exc_info=True)
        return None
class TestFinalizerSalvage:
    """The plain-text finalizer fallback used on structured-output failure (#351)."""

    def test_extract_tool_payloads_pulls_tool_returns(self) -> None:
        """Tool returns are extracted from a captured run trace, in order."""
        tool_content = {"runs": [{"run_id": "abc", "wape": 18.93}]}
        captured: list[ModelMessage] = [
            ModelRequest(parts=[UserPromptPart(content="List runs")]),
            ModelResponse(parts=[TextPart(content="{}")]),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name="tool_list_runs",
                        content=tool_content,
                        tool_call_id="call-1",
                    )
                ]
            ),
        ]

        payloads = AgentService._extract_tool_payloads(captured)

        assert payloads == [{"tool": "tool_list_runs", "result": tool_content}]

    def test_extract_tool_payloads_empty_when_no_tool_returns(self) -> None:
        """No tool returns (model failed before any tool ran) yields an empty list."""
        captured: list[ModelMessage] = [
            ModelRequest(parts=[UserPromptPart(content="List runs")]),
            ModelResponse(parts=[TextPart(content='{"runs": []}')]),
        ]

        assert AgentService._extract_tool_payloads(captured) == []

    @pytest.mark.asyncio
    async def test_salvage_returns_none_without_tool_data(self) -> None:
        """With no captured tool data, salvage returns None (caller emits the error)."""
        service = AgentService()

        assert await service._salvage_plaintext_answer("any question", []) is None

    @pytest.mark.asyncio
    async def test_salvage_returns_none_when_finalizer_fails(self) -> None:
        """A failing finalizer is swallowed: salvage returns None, never raises."""
        service = AgentService()
        captured: list[ModelMessage] = [
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name="tool_list_runs",
                        content={"runs": [{"run_id": "abc", "wape": 18.93}]},
                        tool_call_id="call-1",
                    )
                ]
            ),
        ]

        # The service imports the builder lazily from this module, so patching
        # the module attribute is what the call under test picks up.
        with patch(
            "app.features.agents.agents.base.build_finalizer_agent",
            side_effect=RuntimeError("model unavailable"),
        ):
            result = await service._salvage_plaintext_answer("any question", captured)

        assert result is None

    def test_compact_for_finalizer_strips_verbose_keys_keeps_metrics(self) -> None:
        """Compaction drops bulky config/runtime blobs but keeps identity + metrics (#351).

        Regression for the finalizer reporting 99.0 as "lowest WAPE" when the
        true minimum (18.93) had been truncated out of the oversized payload.
        """
        raw = [
            {
                "tool": "tool_list_runs",
                "result": {
                    "runs": [
                        {
                            "run_id": "a",
                            "model_type": "seasonal_naive",
                            "metrics": {"wape": 99.0},
                            "model_config_data": {"x": "y" * 500},
                            "runtime_info": {"python": "3.12"},
                            "artifact_uri": "demo/seasonal-model_a.joblib",
                        },
                        {
                            "run_id": "b",
                            "model_type": "naive",
                            "metrics": {"wape": 18.93},
                            "feature_config": {"lots": "of stuff"},
                        },
                    ]
                },
            }
        ]

        compact = AgentService._compact_for_finalizer(raw)
        # The helper is annotated as returning ``object``; narrow it before
        # indexing so the test also passes static type checking.
        assert isinstance(compact, list)
        first_run, second_run = compact[0]["result"]["runs"]

        # Identity + metrics survive for BOTH runs (so a ranking sees 18.93).
        assert first_run["run_id"] == "a"
        assert first_run["metrics"] == {"wape": 99.0}
        assert second_run["run_id"] == "b"
        assert second_run["metrics"] == {"wape": 18.93}
        # Verbose blobs are gone.
        for dropped in ("model_config_data", "runtime_info", "artifact_uri"):
            assert dropped not in first_run
        assert "feature_config" not in second_run