diff --git a/app/features/agents/agents/experiment.py b/app/features/agents/agents/experiment.py index d9bad5bd..1b0139f6 100644 --- a/app/features/agents/agents/experiment.py +++ b/app/features/agents/agents/experiment.py @@ -322,6 +322,13 @@ async def tool_create_alias( # Check if approval is required if requires_approval("create_alias"): + # Record a machine-readable approval request so the service layer + # can persist pending_action + emit approval_required (#336). + ctx.deps.set_pending_action( + "create_alias", + {"alias_name": alias_name, "run_id": run_id, "description": description}, + f"Create alias '{alias_name}' pointing at run {run_id}", + ) return { "status": "approval_required", "action": "create_alias", @@ -366,6 +373,13 @@ async def tool_archive_run( # Check if approval is required if requires_approval("archive_run"): + # Record a machine-readable approval request so the service layer + # can persist pending_action + emit approval_required (#336). + ctx.deps.set_pending_action( + "archive_run", + {"run_id": run_id}, + f"Archive run {run_id}", + ) return { "status": "approval_required", "action": "archive_run", @@ -466,6 +480,14 @@ async def tool_save_scenario( # Check if approval is required — mirrors tool_create_alias exactly. if requires_approval("save_scenario"): + # Record a machine-readable approval request so the service layer + # can persist pending_action + emit approval_required (#336). The + # arguments dict is exactly what _execute_pending_action replays. + ctx.deps.set_pending_action( + "save_scenario", + arguments, + f"Save scenario plan '{name}' for store {store_id} / product {product_id}", + ) return { "status": "approval_required", "action": "save_scenario", diff --git a/app/features/agents/deps.py b/app/features/agents/deps.py index 23bcf1f8..3344ad67 100644 --- a/app/features/agents/deps.py +++ b/app/features/agents/deps.py @@ -7,6 +7,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Any from sqlalchemy.ext.asyncio import AsyncSession @@ -23,14 +24,46 @@ class AgentDeps: session_id: Current agent session ID. request_id: Optional request correlation ID for logging. tool_call_count: Counter for tool calls in this run. + pending_action: Machine-readable HITL approval request recorded by a + gated tool when it short-circuits without persisting (#336). The + service layer reads this after the agent run to flip the session to + ``awaiting_approval`` and emit the ``approval_required`` event, + instead of relying on the model echoing the request into its + structured output. """ db: AsyncSession session_id: str request_id: str | None = None tool_call_count: int = field(default=0) + pending_action: dict[str, Any] | None = field(default=None) def increment_tool_calls(self) -> int: """Increment and return the tool call count.""" self.tool_call_count += 1 return self.tool_call_count + + def set_pending_action( + self, + action_type: str, + arguments: dict[str, Any], + description: str, + ) -> None: + """Record that a gated tool call needs human approval (HITL). + + Called by approval-gated tools (e.g. ``save_scenario``, ``create_alias``, + ``archive_run``) instead of persisting their effect. The ``arguments`` + dict must carry everything ``AgentService._execute_pending_action`` needs + to run the action once a human approves it. + + Args: + action_type: The gated action name (``create_alias`` / ``archive_run`` + / ``save_scenario``). + arguments: Arguments to replay when the action is approved. + description: Human-readable summary shown on the approval card. + """ + self.pending_action = { + "action_type": action_type, + "arguments": arguments, + "description": description, + } diff --git a/app/features/agents/service.py b/app/features/agents/service.py index 1b3c4644..cdc83882 100644 --- a/app/features/agents/service.py +++ b/app/features/agents/service.py @@ -314,9 +314,26 @@ async def chat( # NOTE: PydanticAI v1.48.0 uses result.output (not result.data) result_data: Any = result.output - # Check for pending_action in result data (primary trigger) + # Primary trigger (#336): a gated tool recorded a machine-readable + # approval request on deps. Deterministic — does not rely on the model + # echoing the request into its structured output (ExperimentReport has + # no pending_action field, so the legacy triggers below never fired). + if deps.pending_action: + pending_approval = True + pending_action = self._record_pending_action( + session, + action_type=str(deps.pending_action.get("action_type", "unknown")), + arguments=deps.pending_action.get("arguments") or {}, + description=str( + deps.pending_action.get("description") + or f"Agent requested approval for " + f"{deps.pending_action.get('action_type', 'unknown')}" + ), + now=now, + ) + # Legacy trigger: structured output carried a pending_action field. # The agent tools should return a pending_action dict with action_type and arguments - if hasattr(result_data, "pending_action") and result_data.pending_action: + elif hasattr(result_data, "pending_action") and result_data.pending_action: pending_approval = True pending_action_data = result_data.pending_action # Extract action details - support both dict and object with attributes @@ -335,33 +352,19 @@ async def chat( f"Agent requested approval for {action_type}", ) - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": action_type, - "description": description, - "arguments": arguments, - "created_at": now.isoformat(), - "expires_at": ( - now + timedelta(minutes=self.settings.agent_approval_timeout_minutes) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) + pending_action = self._record_pending_action( + session, action_type, arguments, description, now + ) # Fallback: check approval_required flag (legacy trigger) elif hasattr(result_data, "approval_required") and result_data.approval_required: pending_approval = True - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": "unknown", - "description": "Agent requested approval for an action", - "arguments": {}, - "created_at": now.isoformat(), - "expires_at": ( - now + timedelta(minutes=self.settings.agent_approval_timeout_minutes) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) + pending_action = self._record_pending_action( + session, + "unknown", + {}, + "Agent requested approval for an action", + now, + ) # Update session usage = result.usage() @@ -502,8 +505,28 @@ async def stream_chat( pending_approval = False stream_now = datetime.now(UTC) - # Check for pending_action in result data (primary trigger) - if hasattr(final_result, "pending_action") and final_result.pending_action: + # Primary trigger (#336): a gated tool recorded a + # machine-readable approval request on deps. Deterministic + # — the experiment agent's ExperimentReport output has no + # pending_action field, so the legacy triggers below never + # fired and the approval_required event was never emitted. + if deps.pending_action: + pending_approval = True + pending_action = self._record_pending_action( + session, + action_type=str(deps.pending_action.get("action_type", "unknown")), + arguments=deps.pending_action.get("arguments") or {}, + description=str( + deps.pending_action.get("description") + or "Agent requested approval for " + f"{deps.pending_action.get('action_type', 'unknown')}" + ), + now=stream_now, + ) + # Legacy trigger: structured output carried pending_action. + elif ( + hasattr(final_result, "pending_action") and final_result.pending_action + ): pending_approval = True pending_action_data = final_result.pending_action # Extract action details - support both dict and object with attributes @@ -522,42 +545,22 @@ async def stream_chat( f"Agent requested approval for {action_type}", ) - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": action_type, - "description": description, - "arguments": arguments, - "created_at": stream_now.isoformat(), - "expires_at": ( - stream_now - + timedelta( - minutes=self.settings.agent_approval_timeout_minutes - ) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) + pending_action = self._record_pending_action( + session, action_type, arguments, description, stream_now + ) # Fallback: check approval_required flag (legacy trigger) elif ( hasattr(final_result, "approval_required") and final_result.approval_required ): pending_approval = True - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": "unknown", - "description": "Agent requested approval for an action", - "arguments": {}, - "created_at": stream_now.isoformat(), - "expires_at": ( - stream_now - + timedelta( - minutes=self.settings.agent_approval_timeout_minutes - ) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) + pending_action = self._record_pending_action( + session, + "unknown", + {}, + "Agent requested approval for an action", + stream_now, + ) await db.flush() @@ -825,6 +828,45 @@ def _deserialize_messages( ) return [] + def _record_pending_action( + self, + session: AgentSession, + action_type: str, + arguments: dict[str, Any], + description: str, + now: datetime, + ) -> PendingAction | None: + """Persist a HITL approval request on the session and format it. + + Builds the canonical ``session.pending_action`` dict (fresh action_id + + expiry), flips the session to ``awaiting_approval``, and returns the + ``PendingAction`` schema for the response / stream event. Shared by the + deterministic deps-based trigger (#336) and the legacy structured-output + triggers so all three paths persist an identical shape. + + Args: + session: The agent session to mutate. + action_type: Gated action name. + arguments: Arguments to replay on approval. + description: Human-readable approval-card summary. + now: Timestamp used for created_at / expires_at. + + Returns: + The formatted PendingAction, or None if formatting fails. + """ + session.pending_action = { + "action_id": uuid.uuid4().hex[:16], + "action_type": action_type, + "description": description, + "arguments": arguments, + "created_at": now.isoformat(), + "expires_at": ( + now + timedelta(minutes=self.settings.agent_approval_timeout_minutes) + ).isoformat(), + } + session.status = SessionStatus.AWAITING_APPROVAL.value + return self._format_pending_action(session.pending_action) + def _format_pending_action( self, pending: dict[str, Any] | None, diff --git a/app/features/agents/tests/test_service.py b/app/features/agents/tests/test_service.py index 08064495..47b90c31 100644 --- a/app/features/agents/tests/test_service.py +++ b/app/features/agents/tests/test_service.py @@ -777,3 +777,150 @@ def test_increment_tool_calls(self, mock_db_session: AsyncMock) -> None: assert deps.tool_call_count == 1 deps.increment_tool_calls() assert deps.tool_call_count == 2 + + def test_set_pending_action_records_request(self, mock_db_session: AsyncMock) -> None: + """set_pending_action should record a machine-readable HITL request (#336).""" + deps = AgentDeps(db=mock_db_session, session_id="test-123") + assert deps.pending_action is None + + deps.set_pending_action( + "save_scenario", + {"name": "p", "run_id": "r", "store_id": 1, "product_id": 2}, + "Save scenario plan 'p'", + ) + + assert deps.pending_action is not None + assert deps.pending_action["action_type"] == "save_scenario" + assert deps.pending_action["arguments"]["run_id"] == "r" + assert deps.pending_action["description"] == "Save scenario plan 'p'" + + +class TestAgentServiceDepsApproval: + """Regression tests for #336 — gated tools propagate approval via deps. + + The experiment agent's structured output (ExperimentReport) carries no + pending_action/approval_required field, so a gated tool call (e.g. + save_scenario) used to leave the session ``active`` with no pending action + and no ``approval_required`` event. These assert the deterministic + deps-based path: tool -> deps.pending_action -> awaiting_approval -> + approval_required. + """ + + @staticmethod + def _save_scenario_pending(deps: AgentDeps) -> None: + """Simulate the gated save_scenario tool short-circuiting for approval.""" + deps.set_pending_action( + "save_scenario", + { + "name": "plan-a", + "run_id": "702c7ce74e9848d3b11f124a71bf7b50", + "store_id": 111, + "product_id": 339, + "horizon": 14, + "assumptions": {}, + "source": "agent", + "agent_session_id": deps.session_id, + }, + "Save scenario plan 'plan-a' for store 111 / product 339", + ) + + @pytest.mark.asyncio + async def test_chat_persists_pending_action_from_deps( + self, + sample_active_session: AgentSession, + sample_experiment_report: ExperimentReport, + ) -> None: + """chat() must persist deps.pending_action even when the output lacks one.""" + service = AgentService() + mock_db = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + def _run(message: str, *, deps: AgentDeps, message_history: Any) -> MagicMock: + # A gated tool fired during the run and recorded the approval request. + self._save_scenario_pending(deps) + res = MagicMock() + res.output = sample_experiment_report # no pending_action field + usage = MagicMock() + usage.total_tokens = 7 + res.usage.return_value = usage + res.all_messages.return_value = [] + return res + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(side_effect=_run) + + with patch.object(service, "_get_agent", return_value=mock_agent): + response = await service.chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Save a what-if scenario plan for run 702c...", + ) + + assert response.pending_approval is True + assert response.pending_action is not None + assert response.pending_action.action_type == "save_scenario" + assert response.pending_action.arguments["run_id"] == "702c7ce74e9848d3b11f124a71bf7b50" + assert sample_active_session.status == SessionStatus.AWAITING_APPROVAL.value + assert sample_active_session.pending_action is not None + assert sample_active_session.pending_action["action_type"] == "save_scenario" + + @pytest.mark.asyncio + async def test_stream_chat_emits_approval_required_from_deps( + self, + sample_active_session: AgentSession, + sample_experiment_report: ExperimentReport, + ) -> None: + """stream_chat() must emit approval_required from deps.pending_action.""" + service = AgentService() + mock_db = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + report = sample_experiment_report + + class _StubStream: + async def __aenter__(self) -> MagicMock: + stream = MagicMock() + + async def _stream_text() -> AsyncIterator[str]: + # Structured-output agents cannot stream text deltas; mirror + # that by yielding nothing. + return + yield # pragma: no cover + + stream.stream_text = _stream_text + stream.get_output = AsyncMock(return_value=report) + usage = MagicMock() + usage.total_tokens = 9 + stream.usage.return_value = usage + stream.all_messages.return_value = [] + return stream + + async def __aexit__(self, *exc: object) -> bool: + return False + + def _run_stream(message: str, *, deps: AgentDeps, message_history: Any) -> _StubStream: + self._save_scenario_pending(deps) + return _StubStream() + + mock_agent = MagicMock() + mock_agent.run_stream = MagicMock(side_effect=_run_stream) + + with patch.object(service, "_get_agent", return_value=mock_agent): + events = [ + event + async for event in service.stream_chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Save a what-if scenario plan for run 702c...", + ) + ] + + approval_events = [e for e in events if e.event_type == "approval_required"] + assert len(approval_events) == 1 + assert approval_events[0].data["action"].action_type == "save_scenario" + assert sample_active_session.status == SessionStatus.AWAITING_APPROVAL.value + assert sample_active_session.pending_action is not None