Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion app/features/agents/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import httpx
import structlog
from pydantic_ai import ModelRetry
from pydantic_ai import Agent, ModelRetry
from pydantic_ai.models import Model
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIChatModel
Expand Down Expand Up @@ -248,6 +248,40 @@ def build_agent_model_with_fallback() -> Model | str:
return FallbackModel(primary, fallback)


FINALIZER_SYSTEM_PROMPT = """You are a concise analyst for ForecastLabAI.
Answer the user's question using ONLY the provided tool data. Be specific and brief
(2-4 sentences, plain text — no JSON, no preamble).
- If the user asked for a ranking (lowest/highest WAPE, MAE, RMSE, …), name the
specific run/item and its value, and ignore entries whose metric is missing.
- If the data is empty, say so plainly.
- Never invent values, run ids, or entities that are not present in the data.
"""


def build_finalizer_agent() -> Agent[None, str]:
"""Build a tool-less, plain-text agent that salvages an answer from tool data.

Weak local models (e.g. ``ollama:llama3.1:8b``) reliably call tools and obtain
the data, but cannot wrap the result in the primary agent's structured
``PromptedOutput`` schema — they echo the raw tool output and exhaust the
output-retry budget (issue #351). This finalizer takes the data already
obtained and answers in plain text, which weak models *can* do. It has NO
tools (cannot loop) and ``output_type=str`` (cannot fail schema validation),
so it degrades gracefully. Cloud models never need it — it only runs on the
primary agent's misbehavior path.

Returns:
A configured plain-text :class:`Agent`, primary+fallback model wrapped.
"""
model = build_agent_model_with_fallback()
return Agent(
model=model,
output_type=str,
system_prompt=FINALIZER_SYSTEM_PROMPT,
**get_model_settings(),
)


def get_agent_retries() -> int:
"""Get the configured retry budget for agent tool calls and output validation.

Expand Down
185 changes: 173 additions & 12 deletions app/features/agents/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
from __future__ import annotations

import asyncio
import json
import uuid
from collections.abc import AsyncIterator
from contextlib import AbstractContextManager
from datetime import UTC, datetime, timedelta
from typing import Any, Literal, cast

import structlog
from pydantic_ai import Agent
from pydantic_ai import Agent, capture_run_messages
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter
from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter, ToolReturnPart
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -40,6 +41,31 @@

logger = structlog.get_logger()

# Cap on the tool-data JSON fed to the plain-text finalizer (#351). With the
# verbose keys below stripped, a full runs page fits comfortably; the cap is a
# context-budget backstop for pathological payloads.
_FINALIZER_MAX_CHARS = 8000

# Verbose, decision-irrelevant keys stripped from tool results before they are
# handed to the finalizer (#351). Dropping these keeps every run's identity +
# metrics (e.g. WAPE) inside the budget, so a ranking question sees ALL runs
# instead of just the first one or two — the bug where the finalizer reported
# 99.0 as "lowest" while the true minimum (18.93) had been truncated away.
_FINALIZER_DROP_KEYS = frozenset(
{
"model_config",
"model_config_data",
"feature_config",
"runtime_info",
"agent_context",
"config_hash",
"artifact_hash",
"artifact_uri",
"artifact_size_bytes",
"error_message",
}
)


class SessionNotFoundError(ValueError):
"""Session not found in the database."""
Expand Down Expand Up @@ -266,16 +292,20 @@ async def chat(
history_length=len(message_history),
)

# Always bound for the misbehavior handler, even if the run raises before
# capture_run_messages() populates it.
captured_messages: list[ModelMessage] = []
try:
with _sequential_tool_execution():
result = await asyncio.wait_for(
agent.run(
message,
deps=deps,
message_history=message_history,
),
timeout=self.settings.agent_timeout_seconds,
)
with capture_run_messages() as captured_messages:
with _sequential_tool_execution():
result = await asyncio.wait_for(
agent.run(
message,
deps=deps,
message_history=message_history,
),
timeout=self.settings.agent_timeout_seconds,
)
except TimeoutError as e:
raise TimeoutError(
f"Agent response timed out after {self.settings.agent_timeout_seconds} seconds"
Expand Down Expand Up @@ -307,6 +337,13 @@ async def chat(
pending_approval=True,
pending_action=salvaged,
)
# A weak local model often calls tools and obtains the data, then
# fails to wrap it in the structured output schema (#351). Salvage a
# plain-text answer from the tool data already captured this run.
answer = await self._salvage_plaintext_answer(message, captured_messages)
if answer is not None:
logger.info("agents.chat_finalizer_salvage", session_id=session_id)
return ChatResponse(session_id=session_id, message=answer)
return ChatResponse(
session_id=session_id,
message=(
Expand Down Expand Up @@ -483,8 +520,10 @@ async def stream_chat(
default_model = self.settings.agent_default_model
provider = default_model.split(":", 1)[0] if ":" in default_model else ""
stream_supported = provider != "ollama"
# Always bound for the misbehavior handler (see chat()).
captured_messages: list[ModelMessage] = []
try:
with _sequential_tool_execution():
with capture_run_messages() as captured_messages, _sequential_tool_execution():
async with asyncio.timeout(self.settings.agent_timeout_seconds):
final_result: Any
usage: Any
Expand Down Expand Up @@ -694,6 +733,29 @@ async def stream_chat(
timestamp=misbehavior_now,
)
return
# A weak local model often calls tools and obtains the data, then
# fails to wrap it in the structured output schema (#351). Salvage a
# plain-text answer from the tool data already captured this run and
# emit it as a normal reply rather than an error.
answer = await self._salvage_plaintext_answer(message, captured_messages)
if answer is not None:
logger.info("agents.stream_chat_finalizer_salvage", session_id=session_id)
yield StreamEvent(
event_type="text_delta",
data={"delta": answer},
timestamp=misbehavior_now,
)
yield StreamEvent(
event_type="complete",
data={
"message": answer,
"tokens_used": 0,
"tool_calls_count": deps.tool_call_count,
"pending_approval": False,
},
timestamp=misbehavior_now,
)
return
yield StreamEvent(
event_type="error",
data={
Expand Down Expand Up @@ -941,6 +1003,105 @@ def _salvage_pending_action(
now=now,
)

@staticmethod
def _extract_tool_payloads(captured: list[ModelMessage]) -> list[dict[str, Any]]:
"""Pull every tool return out of a captured run's message trace.

Used by :meth:`_salvage_plaintext_answer` to recover the data a weak
model fetched before it failed structured-output validation (#351).

Args:
captured: Messages captured via ``capture_run_messages`` (may be empty
when the run failed before any tool returned).

Returns:
One ``{"tool", "result"}`` dict per ``ToolReturnPart``, in order.
"""
payloads: list[dict[str, Any]] = []
for message in captured:
for part in getattr(message, "parts", []):
if isinstance(part, ToolReturnPart):
payloads.append({"tool": part.tool_name, "result": part.content})
return payloads

@classmethod
def _compact_for_finalizer(cls, obj: object) -> object:
"""Recursively strip verbose, decision-irrelevant keys from tool data (#351).

Keeps each entry's identity + metrics while dropping bulky nested config
/ runtime blobs, so a full result set fits in the finalizer's budget and
a ranking sees every entry. Pure/serialisation-only — no I/O.

Args:
obj: Any JSON-ish value extracted from a tool return.

Returns:
The same structure with :data:`_FINALIZER_DROP_KEYS` removed at every
dict level.
"""
if isinstance(obj, dict):
return {
k: cls._compact_for_finalizer(v)
for k, v in obj.items()
if k not in _FINALIZER_DROP_KEYS
}
if isinstance(obj, list):
return [cls._compact_for_finalizer(v) for v in obj]
return obj

async def _salvage_plaintext_answer(
self,
message: str,
captured: list[ModelMessage],
) -> str | None:
"""Answer in plain text from tool data when structured output failed (#351).

A weak local model (e.g. ``ollama:llama3.1:8b``) reliably calls the read
tool and gets the data, but echoes the raw tool result instead of the
primary agent's ``PromptedOutput`` schema, exhausting the output-retry
budget. The data was obtained, though — so hand it to a tool-less,
``str``-output finalizer that answers the user's question directly. The
finalizer cannot loop (no tools) or fail schema validation (plain text).

Args:
message: The original user message.
captured: Messages captured from the failed run.

Returns:
The finalizer's plain-text answer, or ``None`` when no tool data was
obtained or the finalizer itself errors (caller falls back to the
generic recoverable error).
"""
payloads = self._extract_tool_payloads(captured)
if not payloads:
return None
try:
from app.features.agents.agents.base import build_finalizer_agent

compact = self._compact_for_finalizer(payloads)
data = json.dumps(compact, default=str)[:_FINALIZER_MAX_CHARS]
prompt = (
f"User question:\n{message}\n\n"
f"Data retrieved from tools (JSON):\n{data}\n\n"
"Answer the user's question concisely from this data. If the "
"question asks for the lowest/highest of a metric (e.g. WAPE), "
"compare that metric across ALL entries that have it, ignore "
"entries where it is missing/null, and report the true "
"minimum/maximum with its value."
)
finalizer = build_finalizer_agent()
result = await asyncio.wait_for(
finalizer.run(prompt),
timeout=self.settings.agent_timeout_seconds,
)
text = str(result.output).strip()
return text or None
except Exception:
# Best-effort: a finalizer failure must never replace the original
# recoverable error with a crash.
logger.warning("agents.finalizer_fallback_failed", exc_info=True)
return None

def _record_pending_action(
self,
session: AgentSession,
Expand Down
Loading
Loading