Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions app/features/agents/agents/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,13 @@ async def tool_create_alias(

# Check if approval is required
if requires_approval("create_alias"):
# Record a machine-readable approval request so the service layer
# can persist pending_action + emit approval_required (#336).
ctx.deps.set_pending_action(
"create_alias",
{"alias_name": alias_name, "run_id": run_id, "description": description},
f"Create alias '{alias_name}' pointing at run {run_id}",
)
return {
"status": "approval_required",
"action": "create_alias",
Expand Down Expand Up @@ -366,6 +373,13 @@ async def tool_archive_run(

# Check if approval is required
if requires_approval("archive_run"):
# Record a machine-readable approval request so the service layer
# can persist pending_action + emit approval_required (#336).
ctx.deps.set_pending_action(
"archive_run",
{"run_id": run_id},
f"Archive run {run_id}",
)
return {
"status": "approval_required",
"action": "archive_run",
Expand Down Expand Up @@ -466,6 +480,14 @@ async def tool_save_scenario(

# Check if approval is required — mirrors tool_create_alias exactly.
if requires_approval("save_scenario"):
# Record a machine-readable approval request so the service layer
# can persist pending_action + emit approval_required (#336). The
# arguments dict is exactly what _execute_pending_action replays.
ctx.deps.set_pending_action(
"save_scenario",
arguments,
f"Save scenario plan '{name}' for store {store_id} / product {product_id}",
)
return {
"status": "approval_required",
"action": "save_scenario",
Expand Down
33 changes: 33 additions & 0 deletions app/features/agents/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -23,14 +24,46 @@ class AgentDeps:
session_id: Current agent session ID.
request_id: Optional request correlation ID for logging.
tool_call_count: Counter for tool calls in this run.
pending_action: Machine-readable HITL approval request recorded by a
gated tool when it short-circuits without persisting (#336). The
service layer reads this after the agent run to flip the session to
``awaiting_approval`` and emit the ``approval_required`` event,
instead of relying on the model echoing the request into its
structured output.
"""

db: AsyncSession
session_id: str
request_id: str | None = None
tool_call_count: int = field(default=0)
pending_action: dict[str, Any] | None = field(default=None)

def increment_tool_calls(self) -> int:
    """Bump the tool call counter by one.

    Returns:
        The counter value after the increment.
    """
    updated_count = self.tool_call_count + 1
    self.tool_call_count = updated_count
    return updated_count

def set_pending_action(
    self,
    action_type: str,
    arguments: dict[str, Any],
    description: str,
) -> None:
    """Record that a gated tool call needs human approval (HITL).

    Called by approval-gated tools (e.g. ``save_scenario``, ``create_alias``,
    ``archive_run``) instead of persisting their effect. The ``arguments``
    dict must carry everything ``AgentService._execute_pending_action`` needs
    to run the action once a human approves it.

    Each call replaces whatever was previously stored in
    ``self.pending_action``; only the most recent request is kept.

    Args:
        action_type: The gated action name (``create_alias`` / ``archive_run``
            / ``save_scenario``).
        arguments: Arguments to replay when the action is approved. A shallow
            copy is stored, so the caller may keep mutating its own dict
            without altering the recorded request. A falsy value is stored
            as an empty dict.
        description: Human-readable summary shown on the approval card.
    """
    self.pending_action = {
        "action_type": action_type,
        # Snapshot rather than alias the caller's dict: the request is read
        # later, after the tool has returned, and must reflect the arguments
        # as they were when approval was requested.
        "arguments": dict(arguments or {}),
        "description": description,
    }
158 changes: 100 additions & 58 deletions app/features/agents/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,26 @@ async def chat(
# NOTE: PydanticAI v1.48.0 uses result.output (not result.data)
result_data: Any = result.output

# Check for pending_action in result data (primary trigger)
# Primary trigger (#336): a gated tool recorded a machine-readable
# approval request on deps. Deterministic — does not rely on the model
# echoing the request into its structured output (ExperimentReport has
# no pending_action field, so the legacy triggers below never fired).
if deps.pending_action:
pending_approval = True
pending_action = self._record_pending_action(
session,
action_type=str(deps.pending_action.get("action_type", "unknown")),
arguments=deps.pending_action.get("arguments") or {},
description=str(
deps.pending_action.get("description")
or f"Agent requested approval for "
f"{deps.pending_action.get('action_type', 'unknown')}"
),
now=now,
)
# Legacy trigger: structured output carried a pending_action field.
# The agent tools should return a pending_action dict with action_type and arguments
if hasattr(result_data, "pending_action") and result_data.pending_action:
elif hasattr(result_data, "pending_action") and result_data.pending_action:
pending_approval = True
pending_action_data = result_data.pending_action
# Extract action details - support both dict and object with attributes
Expand All @@ -335,33 +352,19 @@ async def chat(
f"Agent requested approval for {action_type}",
)

session.pending_action = {
"action_id": uuid.uuid4().hex[:16],
"action_type": action_type,
"description": description,
"arguments": arguments,
"created_at": now.isoformat(),
"expires_at": (
now + timedelta(minutes=self.settings.agent_approval_timeout_minutes)
).isoformat(),
}
session.status = SessionStatus.AWAITING_APPROVAL.value
pending_action = self._format_pending_action(session.pending_action)
pending_action = self._record_pending_action(
session, action_type, arguments, description, now
)
# Fallback: check approval_required flag (legacy trigger)
elif hasattr(result_data, "approval_required") and result_data.approval_required:
pending_approval = True
session.pending_action = {
"action_id": uuid.uuid4().hex[:16],
"action_type": "unknown",
"description": "Agent requested approval for an action",
"arguments": {},
"created_at": now.isoformat(),
"expires_at": (
now + timedelta(minutes=self.settings.agent_approval_timeout_minutes)
).isoformat(),
}
session.status = SessionStatus.AWAITING_APPROVAL.value
pending_action = self._format_pending_action(session.pending_action)
pending_action = self._record_pending_action(
session,
"unknown",
{},
"Agent requested approval for an action",
now,
)

# Update session
usage = result.usage()
Expand Down Expand Up @@ -502,8 +505,28 @@ async def stream_chat(
pending_approval = False
stream_now = datetime.now(UTC)

# Check for pending_action in result data (primary trigger)
if hasattr(final_result, "pending_action") and final_result.pending_action:
# Primary trigger (#336): a gated tool recorded a
# machine-readable approval request on deps. Deterministic
# — the experiment agent's ExperimentReport output has no
# pending_action field, so the legacy triggers below never
# fired and the approval_required event was never emitted.
if deps.pending_action:
pending_approval = True
pending_action = self._record_pending_action(
session,
action_type=str(deps.pending_action.get("action_type", "unknown")),
arguments=deps.pending_action.get("arguments") or {},
description=str(
deps.pending_action.get("description")
or "Agent requested approval for "
f"{deps.pending_action.get('action_type', 'unknown')}"
),
now=stream_now,
)
# Legacy trigger: structured output carried pending_action.
elif (
hasattr(final_result, "pending_action") and final_result.pending_action
):
pending_approval = True
pending_action_data = final_result.pending_action
# Extract action details - support both dict and object with attributes
Expand All @@ -522,42 +545,22 @@ async def stream_chat(
f"Agent requested approval for {action_type}",
)

session.pending_action = {
"action_id": uuid.uuid4().hex[:16],
"action_type": action_type,
"description": description,
"arguments": arguments,
"created_at": stream_now.isoformat(),
"expires_at": (
stream_now
+ timedelta(
minutes=self.settings.agent_approval_timeout_minutes
)
).isoformat(),
}
session.status = SessionStatus.AWAITING_APPROVAL.value
pending_action = self._format_pending_action(session.pending_action)
pending_action = self._record_pending_action(
session, action_type, arguments, description, stream_now
)
# Fallback: check approval_required flag (legacy trigger)
elif (
hasattr(final_result, "approval_required")
and final_result.approval_required
):
pending_approval = True
session.pending_action = {
"action_id": uuid.uuid4().hex[:16],
"action_type": "unknown",
"description": "Agent requested approval for an action",
"arguments": {},
"created_at": stream_now.isoformat(),
"expires_at": (
stream_now
+ timedelta(
minutes=self.settings.agent_approval_timeout_minutes
)
).isoformat(),
}
session.status = SessionStatus.AWAITING_APPROVAL.value
pending_action = self._format_pending_action(session.pending_action)
pending_action = self._record_pending_action(
session,
"unknown",
{},
"Agent requested approval for an action",
stream_now,
)

await db.flush()

Expand Down Expand Up @@ -825,6 +828,45 @@ def _deserialize_messages(
)
return []

def _record_pending_action(
    self,
    session: AgentSession,
    action_type: str,
    arguments: dict[str, Any],
    description: str,
    now: datetime,
) -> PendingAction | None:
    """Store a HITL approval request on the session and return its schema form.

    Assembles the ``session.pending_action`` dict with a newly generated
    action_id and an expiry derived from the configured approval timeout,
    sets the session status to ``awaiting_approval``, and hands the stored
    dict to ``_format_pending_action``.

    Args:
        session: The agent session to mutate.
        action_type: Gated action name.
        arguments: Arguments to replay on approval.
        description: Human-readable approval-card summary.
        now: Timestamp used for created_at / expires_at.

    Returns:
        Whatever ``_format_pending_action`` returns for the stored dict
        (annotated as a PendingAction or None).
    """
    # Short random identifier: first 16 hex characters of a UUID4.
    action_id = uuid.uuid4().hex[:16]
    created_at = now.isoformat()
    # The approval window length comes from settings, in minutes.
    approval_window = timedelta(minutes=self.settings.agent_approval_timeout_minutes)
    expires_at = (now + approval_window).isoformat()

    record: dict[str, Any] = {
        "action_id": action_id,
        "action_type": action_type,
        "description": description,
        "arguments": arguments,
        "created_at": created_at,
        "expires_at": expires_at,
    }
    session.pending_action = record
    session.status = SessionStatus.AWAITING_APPROVAL.value
    return self._format_pending_action(session.pending_action)

def _format_pending_action(
self,
pending: dict[str, Any] | None,
Expand Down
Loading