From 555c640dd6bbeb87c23f868ed219f8942d0610f3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 11:26:17 +0200 Subject: [PATCH 01/10] Drive resolver elicitation over the 2026-07-28 input_required flow Resolvers that return Elicit[T] now negotiate the transport by protocol version: at >= 2026-07-28 the framework returns an InputRequiredResult carrying the batched questions and resumes when the client retries with input_responses/request_state; at <= 2025-11-25 it keeps the synchronous ctx.elicit() request. Author-facing code (Resolve/Elicit) is unchanged. resolve_arguments becomes a resumable DAG walk: it reads ctx.input_responses / ctx.request_state, memoizes resolver outcomes by a process-stable module:qualname key, batches independent pending elicitations into one round, serializes dependent ones across rounds, and carries resolved outcomes in request_state so each resolver resolves once per logical call. Outcomes restored from request_state are re-validated into their model via the Elicit[T] return arm. request_state is client-trusted for now (HMAC sealing is a follow-up). Add a render_elicitation_schema helper to elicitation.py, MRTR-loop and codec tests, and document the transport in the migration guide. --- src/mcp/server/elicitation.py | 15 +- src/mcp/server/mcpserver/resolve.py | 259 ++++++++++++++++++---- src/mcp/server/mcpserver/tools/base.py | 11 +- tests/server/mcpserver/test_resolve.py | 296 ++++++++++++++++++++++--- 4 files changed, 501 insertions(+), 80 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index c6faf0065..2f548f64e 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -87,6 +87,18 @@ def _validate_rendered_properties(json_schema: dict[str, Any]) -> None: ) from None +def render_elicitation_schema(schema: type[BaseModel]) -> dict[str, Any]: + """Render a model as the spec-valid `requested_schema` for an elicitation. + + Raises: + TypeError: If a field renders as something the spec's + `PrimitiveSchemaDefinition` does not accept. + """ + json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) + _validate_rendered_properties(json_schema) + return json_schema + + async def elicit_with_validation( session: ServerSession, message: str, @@ -103,8 +115,7 @@ async def elicit_with_validation( For sensitive data like credentials or OAuth flows, use elicit_url() instead. """ - json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) - _validate_rendered_properties(json_schema) + json_schema = render_elicitation_schema(schema) result = await session.elicit_form( message=message, diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 89843a716..2bc6e15d4 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -7,24 +7,40 @@ `Elicit[T]` to ask the client; the framework runs the elicitation and injects the answer. +The framework picks the elicitation transport from the negotiated protocol. At +>= 2026-07-28 it returns an `InputRequiredResult` carrying the batched questions +and resumes when the client retries with `input_responses`/`request_state` +(independent resolvers are asked in one round; a resolver depending on another's +answer is asked in a later round). At <= 2025-11-25 it issues a synchronous +`elicitation/create` request mid-call. Resolved outcomes are carried in +`request_state` across rounds, so each resolver resolves once per logical call. + Whether the consumer receives the unwrapped model or the full `ElicitationResult` union is decided by the consumer's annotation: - `Annotated[T, Resolve(fn)]` -> unwrapped `T`; decline/cancel aborts the call. - `Annotated[ElicitationResult[T], Resolve(fn)]` (or a specific member) -> the full outcome; the consumer branches on accept/decline/cancel. - -Each resolver runs at most once per `tools/call` (memoized by function identity). """ from __future__ import annotations import inspect +import json import typing from collections.abc import Callable, Hashable, Mapping from typing import Annotated, Any, Generic, cast, get_args, get_origin import anyio.to_thread +from mcp_types import ( + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, + InputRequests, + InputRequiredResult, + InputResponses, +) +from mcp_types.version import LATEST_MODERN_VERSION, is_version_at_least from pydantic import BaseModel from typing_extensions import TypeVar @@ -33,6 +49,7 @@ CancelledElicitation, DeclinedElicitation, ElicitationResult, + render_elicitation_schema, ) from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError @@ -43,6 +60,11 @@ # The union members the framework injects when a consumer opts into the outcome. _ELICITATION_RESULT_MEMBERS = (AcceptedElicitation, DeclinedElicitation, CancelledElicitation) +# First protocol revision whose `tools/call` carries elicitation inside +# `InputRequiredResult` rather than as a standalone server-to-client request. +_INPUT_REQUIRED_VERSION = LATEST_MODERN_VERSION # "2026-07-28" +_STATE_VERSION = 1 + class Resolve: """Marker for `Annotated[T, Resolve(fn)]`: fill the parameter by running `fn`.""" @@ -79,10 +101,19 @@ def __init__(self, kind: str, resolve: Resolve | None = None, wants_union: bool class _ResolverPlan: """A resolver's parameters and whether it is async, analyzed once.""" - def __init__(self, fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool) -> None: + def __init__( + self, + fn: Callable[..., Any], + params: dict[str, _ParamPlan], + is_async: bool, + elicit_schema: type[BaseModel] | None, + ) -> None: self.fn = fn self.params = params self.is_async = is_async + # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to + # re-validate an outcome restored from `request_state` into a model. + self.elicit_schema = elicit_schema def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: @@ -117,13 +148,6 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, for name in inspect.signature(fn).parameters: annotation = hints.get(name) if get_origin(annotation) is not Annotated: - # A `Resolve` marker is only honored at the top level; flag (rather than - # silently drop) one buried in a union, e.g. `Annotated[T, Resolve(f)] | None`. - if _contains_resolve(annotation): - raise InvalidSignature( - f"Parameter {name!r} of {_resolver_name(fn)!r} wraps `Resolve(...)` in a " - "union; annotate the parameter directly as `Annotated[T, Resolve(...)]`" - ) continue type_arg, *metadata = get_args(annotation) marker = next((m for m in metadata if isinstance(m, Resolve)), None) @@ -132,23 +156,30 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, return resolved -def _contains_resolve(annotation: Any) -> bool: - """True when a `Resolve` marker is nested inside `annotation` (e.g. a union member).""" - if get_origin(annotation) is Annotated: - return any(isinstance(m, Resolve) for m in get_args(annotation)[1:]) - return any(_contains_resolve(arg) for arg in get_args(annotation)) +def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: + """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. + + Lets an outcome restored from `request_state` (a plain dict) be re-validated + into its model so dependent resolvers and tools receive a typed value. + """ + candidates = get_args(return_annotation) if get_origin(return_annotation) is not None else (return_annotation,) + for candidate in candidates: + if get_origin(candidate) is Elicit: + schema = get_args(candidate)[0] + if isinstance(schema, type) and issubclass(schema, BaseModel): # pragma: no branch + return schema + return None def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). - Handles the subscripted `ElicitationResult[T]` alias (a `TypeAliasType` whose - union is on the origin's `__value__`), the bare `ElicitationResult` alias (the - `__value__` is on `type_arg` itself), an explicit `AcceptedElicitation[T] | ...` - union, and a single member. + Handles the bare `ElicitationResult[T]` alias (a `TypeAliasType` carrying the + union on `__value__`), an explicit `AcceptedElicitation[T] | ... ` union, and a + single member. """ - # Unwrap the `ElicitationResult` alias whether it is bare or subscripted. - value = getattr(type_arg, "__value__", None) or getattr(get_origin(type_arg), "__value__", None) + origin = get_origin(type_arg) + value = getattr(origin, "__value__", None) if value is not None: type_arg = value members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,) @@ -217,7 +248,7 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - plans[key] = _ResolverPlan(fn, params, is_async_callable(fn)) + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return"))) for dep in nested: analyze(dep, stack + (key,)) @@ -241,50 +272,111 @@ def _is_context_annotation(annotation: Any) -> bool: return any(isinstance(c, type) and issubclass(c, Context) for c in candidates) +class _Pending(Exception): + """Internal: a resolver needs client input not yet available this round.""" + + +class _Resolution: + """Per-`tools/call` resolution state, shared across the DAG walk. + + `input_required` selects the transport: at >= 2026-07-28 elicitations are + batched into `pending` and surfaced as an `InputRequiredResult`; at older + revisions each `Elicit` is answered synchronously via `ctx.elicit`. + """ + + def __init__( + self, + plans: Mapping[Hashable, _ResolverPlan], + tool_args: Mapping[str, Any], + context: Context[Any, Any], + input_required: bool, + ) -> None: + self.plans = plans + self.tool_args = tool_args + self.context = context + self.input_required = input_required + self.answers: InputResponses = context.input_responses or {} if input_required else {} + self.state = _decode_state(context.request_state) if input_required else {} + self.cache: dict[str, ElicitationResult[Any]] = {} + self.pending: dict[str, ElicitRequest] = {} + + +def _state_key(fn: Callable[..., Any]) -> str: + """Process-stable wire key for a resolver. + + `id`-based keys aren't stable across `input_required` rounds (a retry may land + on a different worker), so memoize and key `input_requests`/`request_state` by + the resolver's `module:qualname`. Two consumers of the same resolver therefore + share one cache entry, one question, and one stored outcome. + """ + return f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)}" + + async def resolve_arguments( resolved_params: Mapping[str, tuple[Resolve, bool]], plans: Mapping[Hashable, _ResolverPlan], tool_args: Mapping[str, Any], context: Context[Any, Any], -) -> dict[str, Any]: +) -> dict[str, Any] | InputRequiredResult: """Resolve every `Resolve`-marked tool parameter into a concrete value. - Each resolver runs at most once (memoized by function identity). Returns a - mapping of tool parameter name to the value to inject. + Returns the mapping of tool parameter name to injected value when every + resolver is satisfied. When a resolver still needs client input (and the + negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` + carrying the batched questions instead; the tool body is not run. + + Each resolver runs at most once per logical call - across multiple + `input_required` rounds, resolved outcomes are carried in `request_state`. Raises: ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - cache: dict[Hashable, ElicitationResult[Any]] = {} + res = _Resolution(plans, tool_args, context, uses_input_required(context.request_context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): - outcome = await _resolve(marker.fn, plans, tool_args, context, cache) + try: + outcome = await _resolve(marker.fn, res) + except _Pending: + continue injected[name] = outcome if wants_union else _unwrap(outcome, name) + + if res.pending: + return InputRequiredResult( + input_requests=cast("InputRequests", res.pending), + request_state=_encode_state(res.cache), + ) return injected -async def _resolve( - fn: Callable[..., Any], - plans: Mapping[Hashable, _ResolverPlan], - tool_args: Mapping[str, Any], - context: Context[Any, Any], - cache: dict[Hashable, ElicitationResult[Any]], -) -> ElicitationResult[Any]: - key = _resolver_key(fn) - if key in cache: - return cache[key] +async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResult[Any]: + """Resolve one resolver, memoized by its process-stable state key. + + Raises `_Pending` when the resolver (or one of its dependencies) needs client + input that has not arrived yet. + """ + key = _state_key(fn) + if key in res.cache: + return res.cache[key] + if key in res.pending: + # Already asked this round by another consumer; don't run the resolver again. + raise _Pending + + plan = res.plans[_resolver_key(fn)] + if key in res.state: + outcome = _outcome_from_state(res.state[key], plan.elicit_schema) + res.cache[key] = outcome + return outcome - plan = plans[key] kwargs: dict[str, Any] = {} for param_name, param_plan in plan.params.items(): if param_plan.kind == "context": - kwargs[param_name] = context + kwargs[param_name] = res.context elif param_plan.kind == "by_name": - kwargs[param_name] = tool_args[param_name] + kwargs[param_name] = res.tool_args[param_name] else: assert param_plan.resolve is not None - dep_outcome = await _resolve(param_plan.resolve.fn, plans, tool_args, context, cache) + dep_outcome = await _resolve(param_plan.resolve.fn, res) kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) if plan.is_async: @@ -292,25 +384,102 @@ async def _resolve( else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - outcome: ElicitationResult[Any] if isinstance(result, Elicit): - elicit = cast("Elicit[BaseModel]", result) - outcome = await context.elicit(elicit.message, elicit.schema) + outcome = await _elicit(cast("Elicit[BaseModel]", result), key, res) else: # A resolver may return any type (not just `BaseModel`); `model_construct` # wraps it as an accepted result without validating against the schema bound. outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) - cache[key] = outcome + res.cache[key] = outcome return outcome +async def _elicit(elicit: Elicit[BaseModel], key: str, res: _Resolution) -> ElicitationResult[Any]: + """Turn a resolver's `Elicit` into an outcome via the negotiated transport.""" + if not res.input_required: + return await res.context.elicit(elicit.message, elicit.schema) + + answer = res.answers.get(key) + if answer is None: + res.pending[key] = _elicit_request(elicit) + raise _Pending + if not isinstance(answer, ElicitResult): + raise ToolError(f"Resolver {key!r} received a non-elicitation response") + if answer.action == "accept" and answer.content is not None: + return AcceptedElicitation(data=elicit.schema.model_validate(answer.content)) + if answer.action == "decline": + return DeclinedElicitation() + return CancelledElicitation() + + def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: if isinstance(outcome, AcceptedElicitation): return outcome.data raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") +def uses_input_required(protocol_version: str | None) -> bool: + """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28). + + Older revisions still carry a standalone `elicitation/create` server-to-client + request, so the framework keeps the synchronous `ctx.elicit()` path for them. + """ + return protocol_version is not None and is_version_at_least(protocol_version, _INPUT_REQUIRED_VERSION) + + +def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest: + """Render an `Elicit[T]` as the embedded `elicitation/create` request for `input_requests`.""" + json_schema = render_elicitation_schema(elicit.schema) + return ElicitRequest(params=ElicitRequestFormParams(message=elicit.message, requested_schema=json_schema)) + + +def _decode_state(request_state: str | None) -> dict[str, dict[str, Any]]: + """Decode the per-call resolution progress from `request_state`. + + `request_state` is client-trusted (integrity sealing is a follow-up); decode + defensively and treat anything malformed as "no progress yet". + """ + if not request_state: + return {} + try: + decoded: Any = json.loads(request_state) + except json.JSONDecodeError: + return {} + if not isinstance(decoded, dict): + return {} + payload = cast("dict[str, Any]", decoded) + if payload.get("v") != _STATE_VERSION: + return {} + outcomes = payload.get("outcomes") + return cast("dict[str, dict[str, Any]]", outcomes) if isinstance(outcomes, dict) else {} + + +def _encode_state(outcomes: Mapping[str, ElicitationResult[Any]]) -> str: + """Encode resolved outcomes (keyed by resolver path) for the next round.""" + encoded: dict[str, dict[str, Any]] = {} + for path, outcome in outcomes.items(): + entry: dict[str, Any] = {"action": outcome.action} + if isinstance(outcome, AcceptedElicitation): + data = outcome.data + entry["data"] = data.model_dump(mode="json") if isinstance(data, BaseModel) else data + encoded[path] = entry + return json.dumps({"v": _STATE_VERSION, "outcomes": encoded}) + + +def _outcome_from_state(entry: Mapping[str, Any], schema: type[BaseModel] | None) -> ElicitationResult[Any]: + """Rebuild an `ElicitationResult` from a decoded `request_state` entry.""" + action = entry.get("action") + if action == "decline": + return DeclinedElicitation() + if action == "cancel": + return CancelledElicitation() + data = entry.get("data") + if schema is not None and isinstance(data, dict): + data = schema.model_validate(data) + return cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=data)) + + __all__ = [ "Resolve", "Elicit", diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 6aab3c777..50d28f574 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -4,7 +4,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from mcp_types import Icon, ToolAnnotations +from mcp_types import Icon, InputRequiredResult, ToolAnnotations from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import ToolError @@ -135,9 +135,12 @@ async def run( pre_validated: dict[str, Any] | None = None if self.resolved_params: pre_validated = self.fn_metadata.validate_arguments(arguments) - pass_directly |= await resolve_arguments( - self.resolved_params, self.resolver_plans, pre_validated, context - ) + resolved = await resolve_arguments(self.resolved_params, self.resolver_plans, pre_validated, context) + if isinstance(resolved, InputRequiredResult): + # A resolver still needs client input (>= 2026-07-28): surface the + # batched questions instead of running the tool body this round. + return self.fn_metadata.convert_result(resolved) if convert_result else resolved + pass_directly |= resolved result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 1f4f72408..61ec290d0 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,9 +1,17 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" -from typing import Annotated, Any, Literal +from collections.abc import Callable +from typing import Annotated, Literal, cast import pytest -from mcp_types import ElicitRequestParams, ElicitResult, TextContent +from mcp_types import ( + CallToolResult, + ElicitRequestParams, + ElicitResult, + InputRequiredResult, + InputResponses, + TextContent, +) from pydantic import BaseModel, Field from mcp import Client @@ -19,7 +27,15 @@ Resolve, ) from mcp.server.mcpserver.exceptions import InvalidSignature -from mcp.server.mcpserver.resolve import _resolver_key, find_resolved_parameters +from mcp.server.mcpserver.resolve import ( + _decode_state, + _elicit_return_schema, + _encode_state, + _outcome_from_state, + _resolver_key, + find_resolved_parameters, + uses_input_required, +) from mcp.server.mcpserver.tools.base import Tool @@ -53,6 +69,36 @@ async def _text(client: Client, tool: str, args: dict[str, object]) -> str: return result.content[0].text +async def _drive_mrtr( + client: Client, + tool: str, + args: dict[str, object], + answer: Callable[[str, ElicitRequestParams], ElicitResult], + max_rounds: int = 10, +) -> CallToolResult: + """Drive the 2026-07-28 `input_required` loop to completion. + + Re-invokes `tools/call` with `input_responses`/`request_state` until the + server returns a final `CallToolResult`, fulfilling each pending request via + `answer(key, request_params)`. + """ + responses: InputResponses | None = None + state: str | None = None + for _ in range(max_rounds): + result = await client.call_tool( + tool, args, input_responses=responses, request_state=state, allow_input_required=True + ) + if isinstance(result, CallToolResult): + return result + assert isinstance(result, InputRequiredResult) + assert result.input_requests is not None + responses = { + key: answer(key, cast(ElicitRequestParams, req.params)) for key, req in result.input_requests.items() + } + state = result.request_state + raise AssertionError("input_required loop did not converge") # pragma: no cover + + @pytest.mark.anyio async def test_resolver_returns_value_directly_without_eliciting(): mcp = MCPServer(name="Direct") @@ -291,32 +337,6 @@ async def tool(login: Annotated[Login, Resolve(login)]) -> str: Tool.from_function(tool) -def test_resolve_marker_inside_a_union_raises_at_registration(): - async def login(ctx: Context) -> Login: - return Login(username="x") # pragma: no cover - - async def tool(login: Annotated[Login, Resolve(login)] | None = None) -> str: - return login.username if login else "" # pragma: no cover - - with pytest.raises(InvalidSignature, match="wraps `Resolve"): - Tool.from_function(tool) - - -def test_bare_elicitation_result_alias_wants_the_outcome_union(): - # The bare `ElicitationResult` alias (no `[T]` subscription) must still opt into - # the result union, not be treated as wanting the unwrapped model. - async def login(ctx: Context) -> Login: - return Login(username="x") # pragma: no cover - - async def tool(login: object) -> str: - return "x" # pragma: no cover - - bare_alias: Any = ElicitationResult - tool.__annotations__["login"] = Annotated[bare_alias, Resolve(login)] - (_, wants_union) = find_resolved_parameters(tool)["login"] - assert wants_union is True - - def test_resolve_marker_on_return_annotation_is_ignored(): async def login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover @@ -569,3 +589,221 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - async with Client(mcp, mode="legacy", elicitation_callback=callback) as client: assert await _text(client, "delete_folder", {"path": "/docs"}) == expected assert ("/docs" in fs) is (expected != "deleted /docs") + + +@pytest.mark.anyio +async def test_input_required_first_round_returns_the_question(): + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + async with Client(mcp) as client: # mode="auto" negotiates 2026-07-28 + assert client.session.protocol_version == "2026-07-28" + result = await client.call_tool("delete_folder", {"path": "/docs"}, allow_input_required=True) + assert isinstance(result, InputRequiredResult) + assert result.input_requests is not None + (request,) = result.input_requests.values() + assert request.method == "elicitation/create" + assert "/docs has 2 file(s)" in request.params.message + assert result.request_state is not None + assert "/docs" in fs # nothing deleted before the answer arrives + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("action", "content", "expected"), + [ + ("accept", {"ok": True}, "deleted /docs"), + ("accept", {"ok": False}, "kept the folder"), + ("decline", None, "declined: folder not deleted"), + ("cancel", None, "cancelled: folder not deleted"), + ], +) +async def test_input_required_loop_handles_every_outcome( + action: Literal["accept", "decline", "cancel"], + content: dict[str, str | int | float | bool | list[str] | None] | None, + expected: str, +): + mcp, fs = _delete_folder_server() + fs["/docs"] = ["a.txt", "b.txt"] + + def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + assert "/docs has 2 file(s)" in params.message + return ElicitResult(action=action, content=content) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "delete_folder", {"path": "/docs"}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == expected + assert ("/docs" in fs) is (expected != "deleted /docs") + + +@pytest.mark.anyio +async def test_input_required_empty_folder_completes_in_one_round(): + mcp, fs = _delete_folder_server() + fs["/empty"] = [] + + def never(key: str, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + raise AssertionError("should not elicit for an empty folder") + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "delete_folder", {"path": "/empty"}, never) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "deleted /empty" + assert "/empty" not in fs + + +@pytest.mark.anyio +async def test_input_required_resolver_asks_and_consumes_then_never_reruns(): + mcp = MCPServer(name="ExactlyOnceMRTR") + counts = {"login": 0, "confirm": 0} + + async def login(ctx: Context) -> Login | Elicit[Login]: + counts["login"] += 1 + return Elicit("Username?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + counts["confirm"] += 1 + return Elicit(f"As {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + if "Username" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + # An eliciting resolver runs twice - once to ask, once to consume the answer - + # then its outcome is carried in `request_state` and it never runs again. `login` + # asks in round 1 and is consumed in round 2; `confirm` (which depends on + # `login`) only forms its question once `login` is known, so it asks in round 2 + # and is consumed in round 3. Neither re-runs beyond consuming its own answer. + assert counts == {"login": 2, "confirm": 2} + + +@pytest.mark.anyio +async def test_input_required_batches_independent_elicits_in_one_round(): + mcp = MCPServer(name="BatchedMRTR") + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + @mcp.tool() + async def both( + name: Annotated[Login, Resolve(ask_name)], + confirm: Annotated[Confirm, Resolve(ask_confirm)], + ) -> str: + return f"{name.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + if "Name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + # Both independent resolvers are asked together in the first round. + first = await client.call_tool("both", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 + + result = await _drive_mrtr(client, "both", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +def test_uses_input_required_version_gate(): + assert uses_input_required("2026-07-28") is True + assert uses_input_required("2025-11-25") is False + assert uses_input_required(None) is False + + +@pytest.mark.parametrize( + "request_state", + [ + None, + "", + "not json", + '{"v": 99, "outcomes": {}}', # wrong version + '{"v": 1}', # missing outcomes + '{"v": 1, "outcomes": []}', # outcomes not a dict + "[1, 2, 3]", # not an object + ], +) +def test_decode_state_tolerates_malformed_request_state(request_state: str | None): + assert _decode_state(request_state) == {} + + +def test_state_round_trips_accept_decline_cancel(): + outcomes: dict[str, ElicitationResult[BaseModel]] = { + "a": AcceptedElicitation(data=Login(username="octocat")), + "b": DeclinedElicitation(), + "c": CancelledElicitation(), + "d": AcceptedElicitation.model_construct(data="raw-token"), # non-model value + } + decoded = _decode_state(_encode_state(outcomes)) + + accepted = _outcome_from_state(decoded["a"], Login) + assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") + assert isinstance(_outcome_from_state(decoded["b"], None), DeclinedElicitation) + assert isinstance(_outcome_from_state(decoded["c"], None), CancelledElicitation) + raw = _outcome_from_state(decoded["d"], None) + assert isinstance(raw, AcceptedElicitation) and raw.data == "raw-token" + + +def test_elicit_return_schema_extraction(): + async def with_elicit(ctx: Context) -> Login | Elicit[Login]: + return Elicit("?", Login) # pragma: no cover + + async def without_elicit(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + assert _elicit_return_schema(Login | Elicit[Login]) is Login + assert _elicit_return_schema(Login) is None + assert _elicit_return_schema(None) is None + + +@pytest.mark.anyio +async def test_non_elicitation_response_raises(): + from mcp_types import CreateMessageResult, TextContent + + mcp = MCPServer(name="WrongResponse") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + @mcp.tool() + async def tool(name: Annotated[Login, Resolve(ask)]) -> str: + return name.username # pragma: no cover + + async with Client(mcp) as client: + r1 = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(r1, InputRequiredResult) + assert r1.input_requests is not None + (key,) = r1.input_requests + # Answer with a sampling result instead of an elicitation result. + r2 = await client.call_tool( + "tool", + {}, + input_responses={ + key: CreateMessageResult(role="assistant", content=TextContent(type="text", text="x"), model="m") + }, + request_state=r1.request_state, + allow_input_required=True, + ) + assert isinstance(r2, CallToolResult) + assert r2.is_error + assert isinstance(r2.content[0], TextContent) + assert "non-elicitation response" in r2.content[0].text From 4cebc5804199772b440f756d62df104759c39ba9 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 11:50:14 +0200 Subject: [PATCH 02/10] Remove casts from the input_required resolver path Replace every cast() with a checked or properly-typed alternative: model the request_state payload with pydantic (_State/_StateEntry) so the untrusted JSON is validated instead of cast; type _Resolution.pending as InputRequests so an ElicitRequest fits without a cast; add a _is_elicit TypeGuard and an _accepted helper that carry the right types; and narrow req.params via isinstance in the test helper. No behavior change. --- src/mcp/server/mcpserver/resolve.py | 95 +++++++++++++++----------- tests/server/mcpserver/test_resolve.py | 18 ++--- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 2bc6e15d4..b9a378c17 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -26,10 +26,9 @@ from __future__ import annotations import inspect -import json import typing from collections.abc import Callable, Hashable, Mapping -from typing import Annotated, Any, Generic, cast, get_args, get_origin +from typing import Annotated, Any, Generic, Literal, TypeGuard, get_args, get_origin import anyio.to_thread from mcp_types import ( @@ -41,7 +40,7 @@ InputResponses, ) from mcp_types.version import LATEST_MODERN_VERSION, is_version_at_least -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar from mcp.server.elicitation import ( @@ -298,7 +297,7 @@ def __init__( self.answers: InputResponses = context.input_responses or {} if input_required else {} self.state = _decode_state(context.request_state) if input_required else {} self.cache: dict[str, ElicitationResult[Any]] = {} - self.pending: dict[str, ElicitRequest] = {} + self.pending: InputRequests = {} def _state_key(fn: Callable[..., Any]) -> str: @@ -342,10 +341,7 @@ async def resolve_arguments( injected[name] = outcome if wants_union else _unwrap(outcome, name) if res.pending: - return InputRequiredResult( - input_requests=cast("InputRequests", res.pending), - request_state=_encode_state(res.cache), - ) + return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.cache)) return injected @@ -379,23 +375,24 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul dep_outcome = await _resolve(param_plan.resolve.fn, res) kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + result: Any if plan.is_async: result = await fn(**kwargs) else: result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) - if isinstance(result, Elicit): - outcome = await _elicit(cast("Elicit[BaseModel]", result), key, res) + if _is_elicit(result): + outcome = await _elicit(result, key, res) else: - # A resolver may return any type (not just `BaseModel`); `model_construct` - # wraps it as an accepted result without validating against the schema bound. - outcome = cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=result)) + # A resolver may return any type (not just `BaseModel`), so accept it as the + # outcome without validating against the schema bound. + outcome = _accepted(result) res.cache[key] = outcome return outcome -async def _elicit(elicit: Elicit[BaseModel], key: str, res: _Resolution) -> ElicitationResult[Any]: +async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> ElicitationResult[Any]: """Turn a resolver's `Elicit` into an outcome via the negotiated transport.""" if not res.input_required: return await res.context.elicit(elicit.message, elicit.schema) @@ -419,6 +416,20 @@ def _unwrap(outcome: ElicitationResult[Any], name: str) -> Any: raise ToolError(f"Resolver for parameter {name!r} could not resolve: elicitation was {outcome.action}") +def _is_elicit(value: Any) -> TypeGuard[Elicit[Any]]: + """Runtime narrow of a resolver's return value to a (parameter-erased) `Elicit`.""" + return isinstance(value, Elicit) + + +def _accepted(data: Any) -> AcceptedElicitation[Any]: + """Wrap a resolved value as an accepted outcome without schema validation. + + A resolver may return any type (the schema bound only constrains `Elicit[T]`), + and a value restored from `request_state` is already validated. + """ + return AcceptedElicitation[Any].model_construct(data=data) + + def uses_input_required(protocol_version: str | None) -> bool: """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28). @@ -434,50 +445,56 @@ def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest: return ElicitRequest(params=ElicitRequestFormParams(message=elicit.message, requested_schema=json_schema)) -def _decode_state(request_state: str | None) -> dict[str, dict[str, Any]]: +class _StateEntry(BaseModel): + """One resolver's recorded outcome inside `request_state`.""" + + action: Literal["accept", "decline", "cancel"] + data: Any = None + + +class _State(BaseModel): + """The decoded `request_state`: resolver outcomes from earlier rounds.""" + + v: int + outcomes: dict[str, _StateEntry] = {} + + +def _decode_state(request_state: str | None) -> dict[str, _StateEntry]: """Decode the per-call resolution progress from `request_state`. - `request_state` is client-trusted (integrity sealing is a follow-up); decode - defensively and treat anything malformed as "no progress yet". + `request_state` is client-trusted (integrity sealing is a follow-up); validate + it through `_State` and treat anything malformed as "no progress yet". """ if not request_state: return {} try: - decoded: Any = json.loads(request_state) - except json.JSONDecodeError: - return {} - if not isinstance(decoded, dict): - return {} - payload = cast("dict[str, Any]", decoded) - if payload.get("v") != _STATE_VERSION: + state = _State.model_validate_json(request_state) + except ValidationError: return {} - outcomes = payload.get("outcomes") - return cast("dict[str, dict[str, Any]]", outcomes) if isinstance(outcomes, dict) else {} + return state.outcomes if state.v == _STATE_VERSION else {} def _encode_state(outcomes: Mapping[str, ElicitationResult[Any]]) -> str: """Encode resolved outcomes (keyed by resolver path) for the next round.""" - encoded: dict[str, dict[str, Any]] = {} + entries: dict[str, _StateEntry] = {} for path, outcome in outcomes.items(): - entry: dict[str, Any] = {"action": outcome.action} - if isinstance(outcome, AcceptedElicitation): - data = outcome.data - entry["data"] = data.model_dump(mode="json") if isinstance(data, BaseModel) else data - encoded[path] = entry - return json.dumps({"v": _STATE_VERSION, "outcomes": encoded}) + data = outcome.data if isinstance(outcome, AcceptedElicitation) else None + if isinstance(data, BaseModel): + data = data.model_dump(mode="json") + entries[path] = _StateEntry(action=outcome.action, data=data) + return _State(v=_STATE_VERSION, outcomes=entries).model_dump_json() -def _outcome_from_state(entry: Mapping[str, Any], schema: type[BaseModel] | None) -> ElicitationResult[Any]: +def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> ElicitationResult[Any]: """Rebuild an `ElicitationResult` from a decoded `request_state` entry.""" - action = entry.get("action") - if action == "decline": + if entry.action == "decline": return DeclinedElicitation() - if action == "cancel": + if entry.action == "cancel": return CancelledElicitation() - data = entry.get("data") + data = entry.data if schema is not None and isinstance(data, dict): data = schema.model_validate(data) - return cast("AcceptedElicitation[Any]", AcceptedElicitation.model_construct(data=data)) + return _accepted(data) __all__ = [ diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 61ec290d0..5ef3e7177 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,11 +1,12 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" from collections.abc import Callable -from typing import Annotated, Literal, cast +from typing import Annotated, Literal import pytest from mcp_types import ( CallToolResult, + ElicitRequestFormParams, ElicitRequestParams, ElicitResult, InputRequiredResult, @@ -73,7 +74,7 @@ async def _drive_mrtr( client: Client, tool: str, args: dict[str, object], - answer: Callable[[str, ElicitRequestParams], ElicitResult], + answer: Callable[[str, ElicitRequestFormParams], ElicitResult], max_rounds: int = 10, ) -> CallToolResult: """Drive the 2026-07-28 `input_required` loop to completion. @@ -92,9 +93,10 @@ async def _drive_mrtr( return result assert isinstance(result, InputRequiredResult) assert result.input_requests is not None - responses = { - key: answer(key, cast(ElicitRequestParams, req.params)) for key, req in result.input_requests.items() - } + responses = {} + for key, req in result.input_requests.items(): + assert isinstance(req.params, ElicitRequestFormParams) + responses[key] = answer(key, req.params) state = result.request_state raise AssertionError("input_required loop did not converge") # pragma: no cover @@ -626,7 +628,7 @@ async def test_input_required_loop_handles_every_outcome( mcp, fs = _delete_folder_server() fs["/docs"] = ["a.txt", "b.txt"] - def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: assert "/docs has 2 file(s)" in params.message return ElicitResult(action=action, content=content) @@ -672,7 +674,7 @@ async def act( ) -> str: return f"{login.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: if "Username" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) @@ -707,7 +709,7 @@ async def both( ) -> str: return f"{name.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestParams) -> ElicitResult: + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: if "Name" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) From 31ccb807278ed8a319a3fbddbb4768f9459f3090 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 13:28:55 +0200 Subject: [PATCH 03/10] Fix resolver MRTR edge cases from review - In-call cache keyed by _resolver_key (instance-distinct) again; _state_key adds id(__self__) for the wire key, so two instances of one bound method no longer collide and silently share an outcome. - resolve_arguments reads ctx.protocol_version (new Context property, None outside a request) instead of dereferencing request_context, so direct MCPServer.call_tool() works for tools whose resolvers never elicit. - request_state persists only elicited outcomes (always validated models); a resolver that resolves without eliciting is pure and re-runs each round. Fixes the json.dumps crash on non-serializable returns (datetime/set/...) and the dict-degradation of restored values. - _elicit_return_schema handles a bare Elicit[T] return (not only unions). - _INPUT_REQUIRED_VERSION pinned to '2026-07-28' instead of LATEST_MODERN_VERSION. - accept with no content raises ToolError instead of silently reporting cancel. - Independent nested resolver deps batch into one round (catch _Pending per dep). - Test cleanup: drop dead helpers, hoist CreateMessageResult import. Add regression tests for each; document the narrowed elicited-only persistence. --- src/mcp/server/mcpserver/context.py | 5 + src/mcp/server/mcpserver/resolve.py | 101 +++++++++----- tests/server/mcpserver/test_resolve.py | 186 +++++++++++++++++++++++-- 3 files changed, 249 insertions(+), 43 deletions(-) diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 4d494db6e..82a6fa2b6 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -232,6 +232,11 @@ def request_id(self) -> str: """Get the unique ID for this request.""" return str(self.request_context.request_id) + @property + def protocol_version(self) -> str | None: + """The negotiated protocol version, or `None` outside of an active request.""" + return self._request_context.protocol_version if self._request_context is not None else None + @property def input_responses(self) -> InputResponses | None: """Client responses to a prior `InputRequiredResult.input_requests`. diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index b9a378c17..1986adf4a 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -12,8 +12,9 @@ and resumes when the client retries with `input_responses`/`request_state` (independent resolvers are asked in one round; a resolver depending on another's answer is asked in a later round). At <= 2025-11-25 it issues a synchronous -`elicitation/create` request mid-call. Resolved outcomes are carried in -`request_state` across rounds, so each resolver resolves once per logical call. +`elicitation/create` request mid-call. Only *elicited* outcomes are carried in +`request_state` across rounds (so the user is asked each question once); a +resolver that returns a value without eliciting is pure and may re-run each round. Whether the consumer receives the unwrapped model or the full `ElicitationResult` union is decided by the consumer's annotation: @@ -26,6 +27,7 @@ from __future__ import annotations import inspect +import types import typing from collections.abc import Callable, Hashable, Mapping from typing import Annotated, Any, Generic, Literal, TypeGuard, get_args, get_origin @@ -39,7 +41,7 @@ InputRequiredResult, InputResponses, ) -from mcp_types.version import LATEST_MODERN_VERSION, is_version_at_least +from mcp_types.version import is_version_at_least from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar @@ -61,7 +63,8 @@ # First protocol revision whose `tools/call` carries elicitation inside # `InputRequiredResult` rather than as a standalone server-to-client request. -_INPUT_REQUIRED_VERSION = LATEST_MODERN_VERSION # "2026-07-28" +# Pinned (not `LATEST_MODERN_VERSION`, which moves when newer revisions are added). +_INPUT_REQUIRED_VERSION = "2026-07-28" _STATE_VERSION = 1 @@ -158,10 +161,12 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. - Lets an outcome restored from `request_state` (a plain dict) be re-validated - into its model so dependent resolvers and tools receive a typed value. + Handles a bare `-> Elicit[T]` and a `-> T | Elicit[T]` union. Lets an elicited + outcome restored from `request_state` (a plain dict) be re-validated into its + model so dependent resolvers and tools receive a typed value. """ - candidates = get_args(return_annotation) if get_origin(return_annotation) is not None else (return_annotation,) + # A bare `Elicit[T]` is itself a candidate; a union contributes its members. + candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) for candidate in candidates: if get_origin(candidate) is Elicit: schema = get_args(candidate)[0] @@ -170,6 +175,10 @@ def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: return None +def _is_union(annotation: Any) -> bool: + return get_origin(annotation) in (typing.Union, types.UnionType) + + def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). @@ -296,19 +305,26 @@ def __init__( self.input_required = input_required self.answers: InputResponses = context.input_responses or {} if input_required else {} self.state = _decode_state(context.request_state) if input_required else {} - self.cache: dict[str, ElicitationResult[Any]] = {} + # In-call dedup keyed by resolver identity (distinguishes two instances of + # the same bound method); `elicited` holds only outcomes that came from an + # elicitation, keyed by their wire key - these are what `request_state` + # persists, since pure resolvers are cheap to re-run each round. + self.cache: dict[Hashable, ElicitationResult[Any]] = {} + self.elicited: dict[str, ElicitationResult[Any]] = {} self.pending: InputRequests = {} def _state_key(fn: Callable[..., Any]) -> str: - """Process-stable wire key for a resolver. + """Process-stable wire key for a resolver's elicitation. - `id`-based keys aren't stable across `input_required` rounds (a retry may land - on a different worker), so memoize and key `input_requests`/`request_state` by - the resolver's `module:qualname`. Two consumers of the same resolver therefore - share one cache entry, one question, and one stored outcome. + `id(fn)` isn't stable across `input_required` rounds, so key `input_requests` / + `request_state` by `module:qualname`. Bound methods add their `__self__` id so + two instances of the same method get distinct questions and stored outcomes + (the registered `Resolve(...)` holds the instance for the call's lifetime). """ - return f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)}" + base = f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)!s}" + bound_self = getattr(fn, "__self__", None) + return f"{base}#{id(bound_self)}" if bound_self is not None else base async def resolve_arguments( @@ -324,14 +340,19 @@ async def resolve_arguments( negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` carrying the batched questions instead; the tool body is not run. - Each resolver runs at most once per logical call - across multiple - `input_required` rounds, resolved outcomes are carried in `request_state`. + An eliciting resolver asks its question once - its answer is carried in + `request_state` across rounds - while a resolver that resolves without + eliciting is pure and may re-run on each round. Raises: ToolError: If an elicited value is declined or cancelled and the consumer asked for the unwrapped model (rather than the result union). """ - res = _Resolution(plans, tool_args, context, uses_input_required(context.request_context.protocol_version)) + # `ctx.protocol_version` is `None` outside an active request: `MCPServer.call_tool()` + # called directly builds such a `Context`, and a tool whose resolvers never elicit + # must still work there. A missing version means the synchronous (non-input_required) + # transport, which never reaches a server-to-client request anyway. + res = _Resolution(plans, tool_args, context, uses_input_required(context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): try: @@ -341,30 +362,32 @@ async def resolve_arguments( injected[name] = outcome if wants_union else _unwrap(outcome, name) if res.pending: - return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.cache)) + return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.elicited)) return injected async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResult[Any]: - """Resolve one resolver, memoized by its process-stable state key. + """Resolve one resolver, deduped within the call by its resolver identity. Raises `_Pending` when the resolver (or one of its dependencies) needs client input that has not arrived yet. """ - key = _state_key(fn) - if key in res.cache: - return res.cache[key] - if key in res.pending: + cache_key = _resolver_key(fn) + if cache_key in res.cache: + return res.cache[cache_key] + + plan = res.plans[cache_key] + wire_key = _state_key(fn) + if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending - - plan = res.plans[_resolver_key(fn)] - if key in res.state: - outcome = _outcome_from_state(res.state[key], plan.elicit_schema) - res.cache[key] = outcome + if wire_key in res.state: + outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) + res.cache[cache_key] = outcome return outcome kwargs: dict[str, Any] = {} + dep_pending = False for param_name, param_plan in plan.params.items(): if param_plan.kind == "context": kwargs[param_name] = res.context @@ -372,8 +395,16 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul kwargs[param_name] = res.tool_args[param_name] else: assert param_plan.resolve is not None - dep_outcome = await _resolve(param_plan.resolve.fn, res) + try: + # Visit every dependency so independent ones that need input are all + # collected into `res.pending` and batched into a single round. + dep_outcome = await _resolve(param_plan.resolve.fn, res) + except _Pending: + dep_pending = True + continue kwargs[param_name] = dep_outcome if param_plan.wants_union else _unwrap(dep_outcome, param_name) + if dep_pending: + raise _Pending result: Any if plan.is_async: @@ -382,13 +413,15 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul result = await anyio.to_thread.run_sync(lambda: fn(**kwargs)) if _is_elicit(result): - outcome = await _elicit(result, key, res) + outcome = await _elicit(result, wire_key, res) + res.elicited[wire_key] = outcome else: # A resolver may return any type (not just `BaseModel`), so accept it as the - # outcome without validating against the schema bound. + # outcome without validating against the schema bound. Plain outcomes are not + # persisted in `request_state`; the resolver re-runs next round instead. outcome = _accepted(result) - res.cache[key] = outcome + res.cache[cache_key] = outcome return outcome @@ -403,7 +436,9 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio raise _Pending if not isinstance(answer, ElicitResult): raise ToolError(f"Resolver {key!r} received a non-elicitation response") - if answer.action == "accept" and answer.content is not None: + if answer.action == "accept": + if answer.content is None: + raise ToolError(f"Resolver {key!r} received an accepted elicitation with no content") return AcceptedElicitation(data=elicit.schema.model_validate(answer.content)) if answer.action == "decline": return DeclinedElicitation() diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 5ef3e7177..ef0c1c92f 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -6,6 +6,7 @@ import pytest from mcp_types import ( CallToolResult, + CreateMessageResult, ElicitRequestFormParams, ElicitRequestParams, ElicitResult, @@ -766,21 +767,14 @@ def test_state_round_trips_accept_decline_cancel(): def test_elicit_return_schema_extraction(): - async def with_elicit(ctx: Context) -> Login | Elicit[Login]: - return Elicit("?", Login) # pragma: no cover - - async def without_elicit(ctx: Context) -> Login: - return Login(username="x") # pragma: no cover - - assert _elicit_return_schema(Login | Elicit[Login]) is Login - assert _elicit_return_schema(Login) is None + assert _elicit_return_schema(Elicit[Login]) is Login # bare Elicit[T] + assert _elicit_return_schema(Login | Elicit[Login]) is Login # union arm + assert _elicit_return_schema(Login) is None # no Elicit arm assert _elicit_return_schema(None) is None @pytest.mark.anyio async def test_non_elicitation_response_raises(): - from mcp_types import CreateMessageResult, TextContent - mcp = MCPServer(name="WrongResponse") async def ask(ctx: Context) -> Elicit[Login]: @@ -809,3 +803,175 @@ async def tool(name: Annotated[Login, Resolve(ask)]) -> str: assert r2.is_error assert isinstance(r2.content[0], TextContent) assert "non-elicitation response" in r2.content[0].text + + +@pytest.mark.anyio +async def test_direct_call_tool_with_non_eliciting_resolver(): + # `MCPServer.call_tool()` called directly builds a Context with no request, so + # `ctx.protocol_version` is None. A tool whose resolvers never elicit must still + # work there (regression: it used to raise "Context is not available"). + mcp = MCPServer(name="Direct") + + async def whoami(ctx: Context) -> Login: + return Login(username="direct") + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(whoami)]) -> str: + return login.username + + result = await mcp.call_tool("tool", {}, Context(mcp_server=mcp)) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "direct" + + +@pytest.mark.anyio +async def test_two_instances_of_one_method_do_not_collide(): + mcp = MCPServer(name="Instances") + + class Service: + def __init__(self, name: str) -> None: + self.name = name + + async def who(self, ctx: Context) -> Login: + return Login(username=self.name) + + alice, bob = Service("alice"), Service("bob") + + @mcp.tool() + async def both( + a: Annotated[Login, Resolve(alice.who)], + b: Annotated[Login, Resolve(bob.who)], + ) -> str: + return f"{a.username},{b.username}" + + result = await mcp.call_tool("both", {}, Context(mcp_server=mcp)) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "alice,bob" + + +@pytest.mark.anyio +async def test_non_serializable_sibling_resolver_does_not_break_rounds(): + from datetime import datetime + + mcp = MCPServer(name="NonSerializable") + + async def clock(ctx: Context) -> datetime: + return datetime(2026, 1, 1) + + async def ask(ctx: Context) -> Elicit[Confirm]: + return Elicit("ok?", Confirm) + + @mcp.tool() + async def act( + when: Annotated[datetime, Resolve(clock)], + confirm: Annotated[Confirm, Resolve(ask)], + ) -> str: + return f"{when.year}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "2026:True" + + +@pytest.mark.anyio +async def test_bare_elicit_dependency_restored_as_model(): + # A `-> Elicit[Login]` (bare, no union) resolver feeds a dependent resolver. After + # the round-trip the dependency must come back as a Login model, not a raw dict. + mcp = MCPServer(name="BareElicitDep") + + async def login(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "user" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "as octocat?" in params.message # proves login was a real model + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_accept_with_no_content_is_an_error_not_a_cancel(): + mcp = MCPServer(name="AcceptNoContent") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username # pragma: no cover + + async with Client(mcp) as client: + r1 = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(r1, InputRequiredResult) + assert r1.input_requests is not None + (key,) = r1.input_requests + r2 = await client.call_tool( + "tool", + {}, + input_responses={key: ElicitResult(action="accept", content=None)}, + request_state=r1.request_state, + allow_input_required=True, + ) + assert isinstance(r2, CallToolResult) + assert r2.is_error + assert isinstance(r2.content[0], TextContent) + assert "no content" in r2.content[0].text + + +@pytest.mark.anyio +async def test_independent_nested_deps_batch_into_one_round(): + mcp = MCPServer(name="NestedBatch") + + async def ask_a(ctx: Context) -> Elicit[Login]: + return Elicit("A name?", Login) + + async def ask_b(ctx: Context) -> Elicit[Confirm]: + return Elicit("B confirm?", Confirm) + + # `combine` depends on two independent eliciting resolvers; both must be asked + # in the same round, not serialized across two InputRequiredResult rounds. + async def combine( + a: Annotated[Login, Resolve(ask_a)], + b: Annotated[Confirm, Resolve(ask_b)], + ) -> Login: + return Login(username=f"{a.username}:{b.ok}") + + @mcp.tool() + async def tool(combined: Annotated[Login, Resolve(combine)]) -> str: + return combined.username + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + first = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 # batched, not serialized + + result = await _drive_mrtr(client, "tool", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" From fad97fdf79d82c7d2c54a8a1f3abb7361a5a8993 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 13:38:26 +0200 Subject: [PATCH 04/10] Fix two MRTR state bugs found by Codex review - Carry restored answers forward: an elicited outcome restored from request_state is now re-added to res.elicited, so in a 4+-round dependency chain an early answer is not dropped from request_state and re-asked on a later round. - Collision-free wire keys: assign each resolver a deterministic wire key at registration (module:qualname, disambiguated with #N when bases collide), so two distinct closures from one factory get separate questions/outcomes instead of sharing one. _state_key is now only the base-key source at registration. Add regression tests: a deep chain asserting an early answer is asked once, and factory closures asserting distinct wire keys and correct injected values. --- src/mcp/server/mcpserver/resolve.py | 23 +++++++- tests/server/mcpserver/test_resolve.py | 77 ++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 1986adf4a..383963d68 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -109,6 +109,7 @@ def __init__( params: dict[str, _ParamPlan], is_async: bool, elicit_schema: type[BaseModel] | None, + wire_key: str, ) -> None: self.fn = fn self.params = params @@ -116,6 +117,10 @@ def __init__( # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to # re-validate an outcome restored from `request_state` into a model. self.elicit_schema = elicit_schema + # Deterministic, collision-free key for this resolver's elicitation on the + # wire (`input_requests`/`request_state`). Assigned at registration so it is + # stable across rounds even when `module:qualname` collides (closures). + self.wire_key = wire_key def _type_hints(fn: Callable[..., Any]) -> dict[str, Any]: @@ -226,6 +231,9 @@ def build_resolver_plans( or a tool argument by name). """ plans: dict[Hashable, _ResolverPlan] = {} + # Count how many distinct resolvers share each `module:qualname` base so closures + # from one factory get distinct, deterministic wire keys (`base`, `base#1`, ...). + base_counts: dict[str, int] = {} def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: key = _resolver_key(fn) @@ -234,6 +242,11 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: if key in plans: return + base = _state_key(fn) + seen = base_counts.get(base, 0) + base_counts[base] = seen + 1 + wire_key = base if seen == 0 else f"{base}#{seen}" + hints = _type_hints(fn) sig = inspect.signature(fn) params: dict[str, _ParamPlan] = {} @@ -256,7 +269,9 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return"))) + plans[key] = _ResolverPlan( + fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return")), wire_key + ) for dep in nested: analyze(dep, stack + (key,)) @@ -377,13 +392,17 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul return res.cache[cache_key] plan = res.plans[cache_key] - wire_key = _state_key(fn) + wire_key = plan.wire_key if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending if wire_key in res.state: outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) res.cache[cache_key] = outcome + # Carry the restored answer forward: if a later resolver is still pending, + # the next round's `request_state` is built from `res.elicited`, so an + # earlier answer must stay there or it would be dropped and re-asked. + res.elicited[wire_key] = outcome return outcome kwargs: dict[str, Any] = {} diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index ef0c1c92f..2f9742cad 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -975,3 +975,80 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: result = await _drive_mrtr(client, "tool", {}, answer) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_deep_chain_keeps_early_answers_across_rounds(): + # A 4-round dependency chain where an early answer (A) must survive in + # request_state while later resolvers are asked. It must be asked exactly once. + mcp = MCPServer(name="DeepChain") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("A name?", Login) + + async def rb(a: Annotated[Login, Resolve(ra)]) -> Elicit[Confirm]: + return Elicit("B?", Confirm) + + async def rc(b: Annotated[Confirm, Resolve(rb)]) -> Elicit[Confirm]: + return Elicit("C?", Confirm) + + async def rd(c: Annotated[Confirm, Resolve(rc)]) -> Elicit[Confirm]: + return Elicit("D?", Confirm) + + # Depends on `ra` directly AND on `rd` (which transitively needs ra->rb->rc). + @mcp.tool() + async def tool( + a: Annotated[Login, Resolve(ra)], + d: Annotated[Confirm, Resolve(rd)], + ) -> str: + return f"{a.username}:{d.ok}" + + a_asks = 0 + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + nonlocal a_asks + if "name" in params.message: + a_asks += 1 + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "tool", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + assert a_asks == 1 # ra's answer survived in request_state; never re-asked + + +@pytest.mark.anyio +async def test_factory_closures_get_distinct_wire_keys(): + # Two resolvers from one factory share module:qualname; they must still get + # distinct questions and their own values (regression: they collided on the wire). + mcp = MCPServer(name="FactoryClosures") + + def make(label: str): + async def resolver(ctx: Context) -> Elicit[Login]: + return Elicit(f"{label}?", Login) + + return resolver + + ask_a, ask_b = make("A"), make("B") + + @mcp.tool() + async def tool( + a: Annotated[Login, Resolve(ask_a)], + b: Annotated[Login, Resolve(ask_b)], + ) -> str: + return f"{a.username},{b.username}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + return ElicitResult(action="accept", content={"username": params.message[0]}) + + async with Client(mcp) as client: + first = await client.call_tool("tool", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert len(first.input_requests) == 2 # distinct keys, not collapsed to one + + result = await _drive_mrtr(client, "tool", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "A,B" From b6d6e07b515a569194ceedbf0b90fdff87d45e94 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 26 Jun 2026 17:21:47 +0200 Subject: [PATCH 05/10] Worker-stable wire keys; restore typed model for unannotated eliciting resolvers Review follow-ups on #2986: - _state_key carries no id(...): it is module:qualname (a callable object uses its type's), so request_state round-trips and resumes on any worker (stateless HTTP). Two resolvers sharing that base (method instances, factory closures) are already disambiguated deterministically at registration (#N), so dropping the id is safe. - An eliciting resolver whose annotation lacks an Elicit[T] arm has elicit_schema None; its answer restored from request_state is now re-validated against the live Elicit.schema (via _elicit consulting res.state) instead of injecting a raw dict. - Move the datetime import to module top (AGENTS.md). Add regression tests: an unannotated eliciting resolver in a multi-round flow, and worker-stable wire keys for method instances and callable objects. --- src/mcp/server/mcpserver/resolve.py | 55 +++++++++++----- tests/server/mcpserver/test_resolve.py | 86 +++++++++++++++++++++++++- 2 files changed, 124 insertions(+), 17 deletions(-) diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 383963d68..870a21024 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -155,6 +155,13 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, for name in inspect.signature(fn).parameters: annotation = hints.get(name) if get_origin(annotation) is not Annotated: + # A `Resolve` marker is only honored at the top level; flag (rather than + # silently drop) one buried in a union, e.g. `Annotated[T, Resolve(f)] | None`. + if _contains_resolve(annotation): + raise InvalidSignature( + f"Parameter {name!r} of {_resolver_name(fn)!r} wraps `Resolve(...)` in a " + "union; annotate the parameter directly as `Annotated[T, Resolve(...)]`" + ) continue type_arg, *metadata = get_args(annotation) marker = next((m for m in metadata if isinstance(m, Resolve)), None) @@ -163,6 +170,13 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, return resolved +def _contains_resolve(annotation: Any) -> bool: + """True when a `Resolve` marker is nested inside `annotation` (e.g. a union member).""" + if get_origin(annotation) is Annotated: + return any(isinstance(m, Resolve) for m in get_args(annotation)[1:]) + return any(_contains_resolve(arg) for arg in get_args(annotation)) + + def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. @@ -187,12 +201,13 @@ def _is_union(annotation: Any) -> bool: def _wants_union(type_arg: Any) -> bool: """True when `type_arg` is an `ElicitationResult` member (or a union of them). - Handles the bare `ElicitationResult[T]` alias (a `TypeAliasType` carrying the - union on `__value__`), an explicit `AcceptedElicitation[T] | ... ` union, and a - single member. + Handles the subscripted `ElicitationResult[T]` alias (a `TypeAliasType` whose + union is on the origin's `__value__`), the bare `ElicitationResult` alias (the + `__value__` is on `type_arg` itself), an explicit `AcceptedElicitation[T] | ...` + union, and a single member. """ - origin = get_origin(type_arg) - value = getattr(origin, "__value__", None) + # Unwrap the `ElicitationResult` alias whether it is bare or subscripted. + value = getattr(type_arg, "__value__", None) or getattr(get_origin(type_arg), "__value__", None) if value is not None: type_arg = value members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,) @@ -330,16 +345,17 @@ def __init__( def _state_key(fn: Callable[..., Any]) -> str: - """Process-stable wire key for a resolver's elicitation. + """Worker-stable base wire key for a resolver, derived only from registration data. - `id(fn)` isn't stable across `input_required` rounds, so key `input_requests` / - `request_state` by `module:qualname`. Bound methods add their `__self__` id so - two instances of the same method get distinct questions and stored outcomes - (the registered `Resolve(...)` holds the instance for the call's lifetime). + `input_requests`/`request_state` must round-trip through the client and resume on + any worker (stateless HTTP), so the key carries no `id(...)`: it is the resolver's + `module:qualname` (a callable object uses its type's). Distinct resolvers that + share this base - two instances of one method, two closures from one factory - are + disambiguated deterministically by `build_resolver_plans` (`base`, `base#1`, ...). """ - base = f"{getattr(fn, '__module__', '')}:{getattr(fn, '__qualname__', fn)!s}" - bound_self = getattr(fn, "__self__", None) - return f"{base}#{id(bound_self)}" if bound_self is not None else base + qualname = getattr(fn, "__qualname__", None) or type(fn).__qualname__ + module = getattr(fn, "__module__", None) or type(fn).__module__ + return f"{module}:{qualname}" async def resolve_arguments( @@ -396,7 +412,11 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending - if wire_key in res.state: + # Restore a prior round's outcome directly only when its model is known from the + # `Elicit[T]` return arm. Without that (a resolver that elicits but isn't annotated + # `-> ... Elicit[T]`), fall through and re-run the resolver so `_elicit` can + # re-validate the stored answer against the live `Elicit.schema`. + if wire_key in res.state and (plan.elicit_schema is not None or res.state[wire_key].action != "accept"): outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) res.cache[cache_key] = outcome # Carry the restored answer forward: if a later resolver is still pending, @@ -449,6 +469,13 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio if not res.input_required: return await res.context.elicit(elicit.message, elicit.schema) + # Answered in a prior round (restored without a known schema, e.g. an unannotated + # resolver): re-validate the stored entry against the live `Elicit.schema`. + if key in res.state and key not in res.answers: + outcome = _outcome_from_state(res.state[key], elicit.schema) + res.elicited[key] = outcome + return outcome + answer = res.answers.get(key) if answer is None: res.pending[key] = _elicit_request(elicit) diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 2f9742cad..c7ab33d0f 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,7 +1,8 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" from collections.abc import Callable -from typing import Annotated, Literal +from datetime import datetime +from typing import Annotated, Any, Literal import pytest from mcp_types import ( @@ -340,6 +341,32 @@ async def tool(login: Annotated[Login, Resolve(login)]) -> str: Tool.from_function(tool) +def test_resolve_marker_inside_a_union_raises_at_registration(): + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(login)] | None = None) -> str: + return login.username if login else "" # pragma: no cover + + with pytest.raises(InvalidSignature, match="wraps `Resolve"): + Tool.from_function(tool) + + +def test_bare_elicitation_result_alias_wants_the_outcome_union(): + # The bare `ElicitationResult` alias (no `[T]` subscription) must still opt into + # the result union, not be treated as wanting the unwrapped model. + async def login(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + async def tool(login: object) -> str: + return "x" # pragma: no cover + + bare_alias: Any = ElicitationResult + tool.__annotations__["login"] = Annotated[bare_alias, Resolve(login)] + (_, wants_union) = find_resolved_parameters(tool)["login"] + assert wants_union is True + + def test_resolve_marker_on_return_annotation_is_ignored(): async def login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover @@ -853,8 +880,6 @@ async def both( @pytest.mark.anyio async def test_non_serializable_sibling_resolver_does_not_break_rounds(): - from datetime import datetime - mcp = MCPServer(name="NonSerializable") async def clock(ctx: Context) -> datetime: @@ -1052,3 +1077,58 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: result = await _drive_mrtr(client, "tool", {}, answer) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "A,B" + + +@pytest.mark.anyio +async def test_eliciting_resolver_without_elicit_arm_restores_a_typed_model(): + # A resolver annotated `-> Login` that actually returns `Elicit(...)` has no + # `Elicit[T]` return arm, so `elicit_schema` is None. Its answer, restored from + # request_state in a 3+ round flow, must still come back as a Login model (not a + # raw dict) so a dependent resolver/tool can use its attributes. + mcp = MCPServer(name="LyingAnnotation") + + # Annotated without an `Elicit[T]` return arm, so `elicit_schema` is None. + async def login(ctx: Context) -> object: + return Elicit("user?", Login) + + async def confirm(login: Annotated[Login, Resolve(login)]) -> Elicit[Confirm]: + return Elicit(f"as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(login)], + confirm: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "user" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + assert "as octocat?" in params.message # login restored as a real model + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp) as client: + result = await _drive_mrtr(client, "act", {}, answer) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "octocat:True" + + +def test_wire_key_is_worker_stable_for_methods_and_callable_objects(): + from mcp.server.mcpserver.resolve import _state_key + + class Service: + async def token(self, ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + class CallableResolver: + async def __call__(self, ctx: Context) -> Login: + return Login(username="x") # pragma: no cover + + a, b = Service(), Service() + # No id(...) in the key: two instances of one method get the same base (they are + # disambiguated at registration, not here), and the key carries no memory address. + assert _state_key(a.token) == _state_key(b.token) + assert "#" not in _state_key(a.token) + assert _state_key(a.token).endswith("Service.token") + # Callable objects key by their type's qualname (they have no `__qualname__`). + assert _state_key(CallableResolver()).endswith("CallableResolver") From 49e3b5f5003c7ad99c152b95da6a1bdb61544feb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:45:50 +0000 Subject: [PATCH 06/10] Harden the resolver input_required path; adapt its tests to the client driver - request_state entries are validated unconditionally when the resolver's model is known; an entry that fails validation is dropped and the question re-asked, matching _decode_state's malformed-means-no-progress stance. Accepted models are encoded by alias so the stored shape round-trips through the same validation the client's fresh answer passed. - Recorded state now wins over a re-sent answer on both restore paths (previously only resolvers without an Elicit[T] return arm honored a changed answer). - A fresh answer whose content does not match the requested schema raises an explicit ToolError instead of leaking pydantic internals. - Embedded elicitation is gated on the client's declared form-elicitation capability: a client that did not declare it gets -32021 with the requiredCapabilities payload (the spec says servers MUST NOT send inputRequests the client has not declared support for). - Legacy parity: elicit_with_validation reports accept-with-no-content explicitly; the unexpected-action arm is gone (the action literal makes the remainder provably cancel). - uses_input_required -> _uses_input_required (module-internal). - Tests drive the loop through the real client surfaces - manual wire shapes via session.call_tool(..., allow_input_required=True), outcomes end-to-end through Client.call_tool's auto-driver - plus regressions for forged state, declined-outcome persistence, unknown keys, aliased models, schema-mismatched answers and the rounds cap. --- src/mcp/server/elicitation.py | 15 +- src/mcp/server/mcpserver/resolve.py | 103 ++++- tests/server/mcpserver/test_resolve.py | 557 +++++++++++++++++++++---- 3 files changed, 558 insertions(+), 117 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 2f548f64e..e730c7bfb 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -114,6 +114,9 @@ async def elicit_with_validation( the user or automatically generating a response. For sensitive data like credentials or OAuth flows, use elicit_url() instead. + + Raises: + ValueError: If the client accepted the elicitation without supplying content. """ json_schema = render_elicitation_schema(schema) @@ -123,17 +126,15 @@ async def elicit_with_validation( related_request_id=related_request_id, ) - if result.action == "accept" and result.content is not None: + if result.action == "accept": + if result.content is None: + raise ValueError("Received an accepted elicitation with no content") # Validate and parse the content using the schema validated_data = schema.model_validate(result.content) return AcceptedElicitation(data=validated_data) - elif result.action == "decline": + if result.action == "decline": return DeclinedElicitation() - elif result.action == "cancel": - return CancelledElicitation() - else: # pragma: no cover - # This should never happen, but handle it just in case - raise ValueError(f"Unexpected elicitation action: {result.action}") + return CancelledElicitation() async def elicit_url( diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 870a21024..e033c5c3c 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -34,12 +34,17 @@ import anyio.to_thread from mcp_types import ( + MISSING_REQUIRED_CLIENT_CAPABILITY, + ClientCapabilities, + ElicitationCapability, ElicitRequest, ElicitRequestFormParams, ElicitResult, + FormElicitationCapability, InputRequests, InputRequiredResult, InputResponses, + MissingRequiredClientCapabilityErrorData, ) from mcp_types.version import is_version_at_least from pydantic import BaseModel, ValidationError @@ -55,6 +60,7 @@ from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError from mcp.shared._callable_inspection import is_async_callable +from mcp.shared.exceptions import MCPError T = TypeVar("T", bound=BaseModel) @@ -189,7 +195,7 @@ def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: for candidate in candidates: if get_origin(candidate) is Elicit: schema = get_args(candidate)[0] - if isinstance(schema, type) and issubclass(schema, BaseModel): # pragma: no branch + if isinstance(schema, type) and issubclass(schema, BaseModel): return schema return None @@ -383,7 +389,7 @@ async def resolve_arguments( # called directly builds such a `Context`, and a tool whose resolvers never elicit # must still work there. A missing version means the synchronous (non-input_required) # transport, which never reaches a server-to-client request anyway. - res = _Resolution(plans, tool_args, context, uses_input_required(context.protocol_version)) + res = _Resolution(plans, tool_args, context, _uses_input_required(context.protocol_version)) injected: dict[str, Any] = {} for name, (marker, wants_union) in resolved_params.items(): try: @@ -417,13 +423,10 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul # `-> ... Elicit[T]`), fall through and re-run the resolver so `_elicit` can # re-validate the stored answer against the live `Elicit.schema`. if wire_key in res.state and (plan.elicit_schema is not None or res.state[wire_key].action != "accept"): - outcome = _outcome_from_state(res.state[wire_key], plan.elicit_schema) - res.cache[cache_key] = outcome - # Carry the restored answer forward: if a later resolver is still pending, - # the next round's `request_state` is built from `res.elicited`, so an - # earlier answer must stay there or it would be dropped and re-asked. - res.elicited[wire_key] = outcome - return outcome + outcome = _restore_outcome(res, wire_key, plan.elicit_schema) + if outcome is not None: + res.cache[cache_key] = outcome + return outcome kwargs: dict[str, Any] = {} dep_pending = False @@ -470,14 +473,16 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio return await res.context.elicit(elicit.message, elicit.schema) # Answered in a prior round (restored without a known schema, e.g. an unannotated - # resolver): re-validate the stored entry against the live `Elicit.schema`. - if key in res.state and key not in res.answers: - outcome = _outcome_from_state(res.state[key], elicit.schema) - res.elicited[key] = outcome + # resolver): re-validate the stored entry against the live `Elicit.schema`. A + # recorded outcome wins over a re-sent answer; an invalid entry self-deletes and + # falls through to the fresh answer (or to re-asking). + outcome = _restore_outcome(res, key, elicit.schema) + if outcome is not None: return outcome answer = res.answers.get(key) if answer is None: + _require_form_elicitation(res.context, key) res.pending[key] = _elicit_request(elicit) raise _Pending if not isinstance(answer, ElicitResult): @@ -485,7 +490,13 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio if answer.action == "accept": if answer.content is None: raise ToolError(f"Resolver {key!r} received an accepted elicitation with no content") - return AcceptedElicitation(data=elicit.schema.model_validate(answer.content)) + try: + data = elicit.schema.model_validate(answer.content) + except ValidationError as e: + raise ToolError( + f"Resolver {key!r} received an accepted elicitation whose content does not match the requested schema" + ) from e + return AcceptedElicitation(data=data) if answer.action == "decline": return DeclinedElicitation() return CancelledElicitation() @@ -511,7 +522,7 @@ def _accepted(data: Any) -> AcceptedElicitation[Any]: return AcceptedElicitation[Any].model_construct(data=data) -def uses_input_required(protocol_version: str | None) -> bool: +def _uses_input_required(protocol_version: str | None) -> bool: """True when this request must elicit via `InputRequiredResult` (>= 2026-07-28). Older revisions still carry a standalone `elicitation/create` server-to-client @@ -520,6 +531,31 @@ def uses_input_required(protocol_version: str | None) -> bool: return protocol_version is not None and is_version_at_least(protocol_version, _INPUT_REQUIRED_VERSION) +def _require_form_elicitation(context: Context[Any, Any], key: str) -> None: + """Assert the client declared form elicitation before queueing a question for it. + + The spec forbids sending an `input_requests` entry the client has not declared a + capability for. A bare `elicitation: {}` declaration (the only shape before modes + existed) counts as form support; an explicit url-only declaration does not. + + Raises: + MCPError: With code `MISSING_REQUIRED_CLIENT_CAPABILITY` and a + `requiredCapabilities` payload when form elicitation is not declared. + """ + capabilities = context.client_capabilities + elicitation = capabilities.elicitation if capabilities is not None else None + if elicitation is not None and (elicitation.form is not None or elicitation.url is None): + return + data = MissingRequiredClientCapabilityErrorData( + required_capabilities=ClientCapabilities(elicitation=ElicitationCapability(form=FormElicitationCapability())) + ) + raise MCPError( + code=MISSING_REQUIRED_CLIENT_CAPABILITY, + message=f"Client did not declare the form elicitation capability required by resolver {key!r}", + data=data.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + + def _elicit_request(elicit: Elicit[Any]) -> ElicitRequest: """Render an `Elicit[T]` as the embedded `elicitation/create` request for `input_requests`.""" json_schema = render_elicitation_schema(elicit.schema) @@ -561,23 +597,54 @@ def _encode_state(outcomes: Mapping[str, ElicitationResult[Any]]) -> str: for path, outcome in outcomes.items(): data = outcome.data if isinstance(outcome, AcceptedElicitation) else None if isinstance(data, BaseModel): - data = data.model_dump(mode="json") + # By alias: the stored shape must round-trip through + # `schema.model_validate` on restore, which expects the alias-keyed + # form the client answered with (the rendered schema is alias-keyed). + data = data.model_dump(mode="json", by_alias=True) entries[path] = _StateEntry(action=outcome.action, data=data) return _State(v=_STATE_VERSION, outcomes=entries).model_dump_json() def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> ElicitationResult[Any]: - """Rebuild an `ElicitationResult` from a decoded `request_state` entry.""" + """Rebuild an `ElicitationResult` from a decoded `request_state` entry. + + Raises: + ValidationError: If `schema` is known and the entry's data does not + validate against it. + """ if entry.action == "decline": return DeclinedElicitation() if entry.action == "cancel": return CancelledElicitation() data = entry.data - if schema is not None and isinstance(data, dict): + if schema is not None: data = schema.model_validate(data) return _accepted(data) +def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) -> ElicitationResult[Any] | None: + """Restore `key`'s recorded outcome from a prior round, or `None` when absent. + + `request_state` is client-trusted, so an entry whose data fails validation gets + the `_decode_state` treatment - dropped as if no progress was recorded, so the + question is asked again - rather than surfacing a validation error. + + Carries a restored outcome forward in `res.elicited`: if a later resolver is + still pending, the next round's `request_state` is built from `res.elicited`, + so an earlier answer must stay there or it would be dropped and re-asked. + """ + entry = res.state.get(key) + if entry is None: + return None + try: + outcome = _outcome_from_state(entry, schema) + except ValidationError: + del res.state[key] + return None + res.elicited[key] = outcome + return outcome + + __all__ = [ "Resolve", "Elicit", diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index c7ab33d0f..4bef8463f 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -1,11 +1,14 @@ """Tests for resolver dependency injection (MRTR) on MCPServer tools.""" +import json from collections.abc import Callable from datetime import datetime from typing import Annotated, Any, Literal +import anyio import pytest from mcp_types import ( + MISSING_REQUIRED_CLIENT_CAPABILITY, CallToolResult, CreateMessageResult, ElicitRequestFormParams, @@ -17,7 +20,7 @@ ) from pydantic import BaseModel, Field -from mcp import Client +from mcp import Client, InputRequiredRoundsExceededError from mcp.client import ClientRequestContext from mcp.server.mcpserver import ( AcceptedElicitation, @@ -36,10 +39,12 @@ _encode_state, _outcome_from_state, _resolver_key, + _state_key, + _uses_input_required, find_resolved_parameters, - uses_input_required, ) from mcp.server.mcpserver.tools.base import Tool +from mcp.shared.exceptions import MCPError class Login(BaseModel): @@ -50,6 +55,10 @@ class Confirm(BaseModel): ok: bool +class Handle(BaseModel): + user_name: str = Field(alias="userName") + + async def _alias_login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover - only the signature is inspected @@ -65,6 +74,12 @@ async def _decline(context: ClientRequestContext, params: ElicitRequestParams) - return ElicitResult(action="decline") +async def _never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + # Declares the form elicitation capability for clients that drive the + # input_required loop manually; the auto-driver never invokes it. + raise AssertionError("should not be called") + + async def _text(client: Client, tool: str, args: dict[str, object]) -> str: result = await client.call_tool(tool, args) assert len(result.content) == 1 @@ -72,35 +87,16 @@ async def _text(client: Client, tool: str, args: dict[str, object]) -> str: return result.content[0].text -async def _drive_mrtr( - client: Client, - tool: str, - args: dict[str, object], - answer: Callable[[str, ElicitRequestFormParams], ElicitResult], - max_rounds: int = 10, -) -> CallToolResult: - """Drive the 2026-07-28 `input_required` loop to completion. - - Re-invokes `tools/call` with `input_responses`/`request_state` until the - server returns a final `CallToolResult`, fulfilling each pending request via - `answer(key, request_params)`. - """ - responses: InputResponses | None = None - state: str | None = None - for _ in range(max_rounds): - result = await client.call_tool( - tool, args, input_responses=responses, request_state=state, allow_input_required=True - ) - if isinstance(result, CallToolResult): - return result - assert isinstance(result, InputRequiredResult) - assert result.input_requests is not None - responses = {} - for key, req in result.input_requests.items(): - assert isinstance(req.params, ElicitRequestFormParams) - responses[key] = answer(key, req.params) - state = result.request_state - raise AssertionError("input_required loop did not converge") # pragma: no cover +def _answer_round( + result: InputRequiredResult, answer: Callable[[str, ElicitRequestFormParams], ElicitResult] +) -> InputResponses: + """Fulfil every question in one `InputRequiredResult` round via `answer(key, request_params)`.""" + assert result.input_requests is not None + responses: InputResponses = {} + for key, req in result.input_requests.items(): + assert isinstance(req.params, ElicitRequestFormParams) + responses[key] = answer(key, req.params) + return responses @pytest.mark.anyio @@ -626,9 +622,9 @@ async def test_input_required_first_round_returns_the_question(): mcp, fs = _delete_folder_server() fs["/docs"] = ["a.txt", "b.txt"] - async with Client(mcp) as client: # mode="auto" negotiates 2026-07-28 + async with Client(mcp, elicitation_callback=_never) as client: # mode="auto" negotiates 2026-07-28 assert client.session.protocol_version == "2026-07-28" - result = await client.call_tool("delete_folder", {"path": "/docs"}, allow_input_required=True) + result = await client.session.call_tool("delete_folder", {"path": "/docs"}, allow_input_required=True) assert isinstance(result, InputRequiredResult) assert result.input_requests is not None (request,) = result.input_requests.values() @@ -653,30 +649,32 @@ async def test_input_required_loop_handles_every_outcome( content: dict[str, str | int | float | bool | list[str] | None] | None, expected: str, ): + # End-to-end at 2026-07-28: the client's auto-driver answers the embedded + # elicitation through the ordinary `elicitation_callback` and retries. mcp, fs = _delete_folder_server() fs["/docs"] = ["a.txt", "b.txt"] - def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: assert "/docs has 2 file(s)" in params.message return ElicitResult(action=action, content=content) - async with Client(mcp) as client: - result = await _drive_mrtr(client, "delete_folder", {"path": "/docs"}, answer) + async with Client(mcp, elicitation_callback=callback) as client: # mode="auto" negotiates 2026-07-28 + result = await client.call_tool("delete_folder", {"path": "/docs"}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == expected assert ("/docs" in fs) is (expected != "deleted /docs") @pytest.mark.anyio -async def test_input_required_empty_folder_completes_in_one_round(): +async def test_input_required_empty_folder_completes_without_eliciting(): mcp, fs = _delete_folder_server() fs["/empty"] = [] - def never(key: str, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover + async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover raise AssertionError("should not elicit for an empty folder") - async with Client(mcp) as client: - result = await _drive_mrtr(client, "delete_folder", {"path": "/empty"}, never) + async with Client(mcp, elicitation_callback=never) as client: + result = await client.call_tool("delete_folder", {"path": "/empty"}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "deleted /empty" assert "/empty" not in fs @@ -702,16 +700,22 @@ async def act( ) -> str: return f"{login.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + asked: list[str] = [] + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params.message) if "Username" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: - result = await _drive_mrtr(client, "act", {}, answer) + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" + # `confirm` can only form its question from `login`'s answer, so the auto-driver + # sees the questions in two successive rounds and answers each exactly once. + assert asked == ["Username?", "As octocat?"] # An eliciting resolver runs twice - once to ask, once to consume the answer - # then its outcome is carried in `request_state` and it never runs again. `login` # asks in round 1 and is consumed in round 2; `confirm` (which depends on @@ -742,22 +746,74 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: + async with Client(mcp, elicitation_callback=_never) as client: # Both independent resolvers are asked together in the first round. - first = await client.call_tool("both", {}, allow_input_required=True) + first = await client.session.call_tool("both", {}, allow_input_required=True) assert isinstance(first, InputRequiredResult) assert first.input_requests is not None assert len(first.input_requests) == 2 - result = await _drive_mrtr(client, "both", {}, answer) + # Answering both and echoing `request_state` completes in a single retry. + final = await client.session.call_tool( + "both", + {}, + input_responses=_answer_round(first, answer), + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_auto_driver_answers_independent_questions_in_a_single_round(): + # The pure `count_round` resolver is never persisted in `request_state`, so it + # re-runs on every round: its run count is the number of rounds the call took. + mcp = MCPServer(name="AutoBatch") + rounds = 0 + + async def count_round(ctx: Context) -> int: + nonlocal rounds + rounds += 1 + return rounds + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + @mcp.tool() + async def both( + round_no: Annotated[int, Resolve(count_round)], + name: Annotated[Login, Resolve(ask_name)], + confirm: Annotated[Confirm, Resolve(ask_confirm)], + ) -> str: + return f"{name.username}:{confirm.ok}" + + asked: list[str] = [] + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params.message) + if "Name" in params.message: + return ElicitResult(action="accept", content={"username": "octocat"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("both", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" + # The driver dispatches batched questions concurrently, so order is unspecified. + assert sorted(asked) == ["Confirm?", "Name?"] # both questions, each exactly once + assert rounds == 2 # one question round, then the completing round + def test_uses_input_required_version_gate(): - assert uses_input_required("2026-07-28") is True - assert uses_input_required("2025-11-25") is False - assert uses_input_required(None) is False + assert _uses_input_required("2026-07-28") is True + assert _uses_input_required("2025-11-25") is False + assert _uses_input_required(None) is False @pytest.mark.parametrize( @@ -798,6 +854,10 @@ def test_elicit_return_schema_extraction(): assert _elicit_return_schema(Login | Elicit[Login]) is Login # union arm assert _elicit_return_schema(Login) is None # no Elicit arm assert _elicit_return_schema(None) is None + # The bound on `Elicit`'s parameter is unenforced at runtime, so a non-model + # subscription is constructible and must yield no schema rather than crash. + unbounded_elicit: Any = Elicit + assert _elicit_return_schema(unbounded_elicit[int]) is None @pytest.mark.anyio @@ -811,13 +871,13 @@ async def ask(ctx: Context) -> Elicit[Login]: async def tool(name: Annotated[Login, Resolve(ask)]) -> str: return name.username # pragma: no cover - async with Client(mcp) as client: - r1 = await client.call_tool("tool", {}, allow_input_required=True) + async with Client(mcp, elicitation_callback=_never) as client: + r1 = await client.session.call_tool("tool", {}, allow_input_required=True) assert isinstance(r1, InputRequiredResult) assert r1.input_requests is not None (key,) = r1.input_requests # Answer with a sampling result instead of an elicitation result. - r2 = await client.call_tool( + r2 = await client.session.call_tool( "tool", {}, input_responses={ @@ -895,11 +955,11 @@ async def act( ) -> str: return f"{when.year}:{confirm.ok}" - def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: - result = await _drive_mrtr(client, "act", {}, answer) + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "2026:True" @@ -923,22 +983,49 @@ async def act( ) -> str: return f"{login.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: if "user" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) assert "as octocat?" in params.message # proves login was a real model return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: - result = await _drive_mrtr(client, "act", {}, answer) + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" @pytest.mark.anyio -async def test_accept_with_no_content_is_an_error_not_a_cancel(): +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_accept_with_no_content_is_an_error_not_a_cancel(mode: Literal["legacy", "auto"]): + # Both transports must agree: mode="legacy" elicits synchronously mid-call, + # mode="auto" rides the 2026-07-28 input_required loop. mcp = MCPServer(name="AcceptNoContent") + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def tool(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username # pragma: no cover + + async def empty_accept(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content=None) + + async with Client(mcp, mode=mode, elicitation_callback=empty_accept) as client: + result = await client.call_tool("tool", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "no content" in result.content[0].text + + +@pytest.mark.anyio +async def test_eliciting_tool_without_client_capability_is_a_protocol_error(): + # The server must not send an `input_requests` entry the client has not declared + # capability for: with no `elicitation` declared (no callback), the call fails as + # a -32021 protocol error, not a CallToolResult execution failure. + mcp = MCPServer(name="NoElicitationCapability") + async def ask(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -947,21 +1034,11 @@ async def tool(login: Annotated[Login, Resolve(ask)]) -> str: return login.username # pragma: no cover async with Client(mcp) as client: - r1 = await client.call_tool("tool", {}, allow_input_required=True) - assert isinstance(r1, InputRequiredResult) - assert r1.input_requests is not None - (key,) = r1.input_requests - r2 = await client.call_tool( - "tool", - {}, - input_responses={key: ElicitResult(action="accept", content=None)}, - request_state=r1.request_state, - allow_input_required=True, - ) - assert isinstance(r2, CallToolResult) - assert r2.is_error - assert isinstance(r2.content[0], TextContent) - assert "no content" in r2.content[0].text + with pytest.raises(MCPError) as exc_info: + await client.session.call_tool("tool", {}, allow_input_required=True) + assert exc_info.value.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert exc_info.value.error.data is not None + assert "elicitation" in exc_info.value.error.data["requiredCapabilities"] @pytest.mark.anyio @@ -991,15 +1068,22 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: - first = await client.call_tool("tool", {}, allow_input_required=True) + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("tool", {}, allow_input_required=True) assert isinstance(first, InputRequiredResult) assert first.input_requests is not None assert len(first.input_requests) == 2 # batched, not serialized - result = await _drive_mrtr(client, "tool", {}, answer) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "octocat:True" + final = await client.session.call_tool( + "tool", + {}, + input_responses=_answer_round(first, answer), + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" @pytest.mark.anyio @@ -1030,15 +1114,15 @@ async def tool( a_asks = 0 - def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: nonlocal a_asks if "name" in params.message: a_asks += 1 return ElicitResult(action="accept", content={"username": "octocat"}) return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: - result = await _drive_mrtr(client, "tool", {}, answer) + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("tool", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" assert a_asks == 1 # ra's answer survived in request_state; never re-asked @@ -1068,15 +1152,22 @@ async def tool( def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: return ElicitResult(action="accept", content={"username": params.message[0]}) - async with Client(mcp) as client: - first = await client.call_tool("tool", {}, allow_input_required=True) + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("tool", {}, allow_input_required=True) assert isinstance(first, InputRequiredResult) assert first.input_requests is not None assert len(first.input_requests) == 2 # distinct keys, not collapsed to one - result = await _drive_mrtr(client, "tool", {}, answer) - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "A,B" + final = await client.session.call_tool( + "tool", + {}, + input_responses=_answer_round(first, answer), + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "A,B" @pytest.mark.anyio @@ -1101,21 +1192,19 @@ async def act( ) -> str: return f"{login.username}:{confirm.ok}" - def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: if "user" in params.message: return ElicitResult(action="accept", content={"username": "octocat"}) assert "as octocat?" in params.message # login restored as a real model return ElicitResult(action="accept", content={"ok": True}) - async with Client(mcp) as client: - result = await _drive_mrtr(client, "act", {}, answer) + async with Client(mcp, elicitation_callback=callback) as client: + result = await client.call_tool("act", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "octocat:True" def test_wire_key_is_worker_stable_for_methods_and_callable_objects(): - from mcp.server.mcpserver.resolve import _state_key - class Service: async def token(self, ctx: Context) -> Login: return Login(username="x") # pragma: no cover @@ -1132,3 +1221,287 @@ async def __call__(self, ctx: Context) -> Login: assert _state_key(a.token).endswith("Service.token") # Callable objects key by their type's qualname (they have no `__qualname__`). assert _state_key(CallableResolver()).endswith("CallableResolver") + + +@pytest.mark.anyio +async def test_declined_outcome_persists_in_request_state_and_is_not_reasked(): + # A decline is recorded in `request_state` just like an accept: RB elicits only + # after seeing RA's decline, so RA's outcome must survive into the round that + # answers RB without RA being asked again. + mcp = MCPServer(name="DeclinePersists") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def rb(a: Annotated[ElicitationResult[Login], Resolve(ra)]) -> Elicit[Confirm]: + assert isinstance(a, DeclinedElicitation) + return Elicit("proceed anonymously?", Confirm) + + @mcp.tool() + async def act( + a: Annotated[ElicitationResult[Login], Resolve(ra)], + c: Annotated[Confirm, Resolve(rb)], + ) -> str: + assert isinstance(a, DeclinedElicitation) + return f"anonymous:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (ra_key,) = first.input_requests + + second = await client.session.call_tool( + "act", + {}, + input_responses={ra_key: ElicitResult(action="decline")}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (rb_key,) = second.input_requests # only RB's question; RA is not re-asked + assert rb_key != ra_key + assert _decode_state(second.request_state)[ra_key].action == "decline" + + final = await client.session.call_tool( + "act", + {}, + input_responses={rb_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "anonymous:True" + + +@pytest.mark.anyio +async def test_unknown_response_keys_and_ghost_state_entries_are_ignored(): + # `input_responses` keys the server never asked for and `request_state` outcome + # entries matching no resolver are tolerated (both are client-supplied), and the + # ghost state entry is not echoed into any later round's `request_state`. + mcp = MCPServer(name="GhostKeys") + + async def ra(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + async def rb(a: Annotated[Login, Resolve(ra)]) -> Elicit[Confirm]: + return Elicit(f"as {a.username}?", Confirm) + + @mcp.tool() + async def act( + a: Annotated[Login, Resolve(ra)], + c: Annotated[Confirm, Resolve(rb)], + ) -> str: + return f"{a.username}:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert first.request_state is not None + (ra_key,) = first.input_requests + + spliced = json.loads(first.request_state) + spliced["outcomes"]["ghost"] = {"action": "accept", "data": {"username": "spooky"}} + second = await client.session.call_tool( + "act", + {}, + input_responses={ + ra_key: ElicitResult(action="accept", content={"username": "octocat"}), + "ghost": ElicitResult(action="accept", content={"username": "spooky"}), + }, + request_state=json.dumps(spliced), + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (rb_key,) = second.input_requests + outcomes = _decode_state(second.request_state) + assert ra_key in outcomes + assert "ghost" not in outcomes # the spliced entry is dropped, not carried onward + + final = await client.session.call_tool( + "act", + {}, + input_responses={rb_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "forged_data", + [ + pytest.param("not-a-dict", id="non-dict-data"), + pytest.param({"hacked": True}, id="dict-failing-schema"), + ], +) +async def test_forged_state_entry_failing_the_schema_is_reasked_not_an_error(forged_data: str | dict[str, bool]): + # `request_state` is client-trusted JSON: an accept entry whose data does not + # validate against the resolver's schema reads as no recorded progress, so the + # question is asked again (not an error) and a proper answer completes the call. + mcp = MCPServer(name="ForgedState") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("whoami", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert first.request_state is not None + (key,) = first.input_requests + + forged = json.loads(first.request_state) + forged["outcomes"][key] = {"action": "accept", "data": forged_data} + second = await client.session.call_tool( + "whoami", {}, request_state=json.dumps(forged), allow_input_required=True + ) + assert isinstance(second, InputRequiredResult) # re-asked, not an error + assert second.input_requests is not None + assert set(second.input_requests) == {key} + assert _decode_state(second.request_state) == {} # the forged entry is dropped + + final = await client.session.call_tool( + "whoami", + {}, + input_responses={key: ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat" + + +@pytest.mark.anyio +async def test_schema_mismatched_fresh_answer_fails_the_call_without_pydantic_leakage(): + # An accepted answer whose content fails the requested schema fails the call with + # the resolver's own message; pydantic's error text (which carries a + # "For further information" link) must not leak to the client. + mcp = MCPServer(name="MismatchedAnswer") + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: + raise NotImplementedError # pragma: no cover - the mismatched answer never reaches the body + + async with Client(mcp, elicitation_callback=_accept({"nope": "x"})) as client: + result = await client.call_tool("whoami", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "does not match the requested schema" in result.content[0].text + assert "For further information" not in result.content[0].text + + +@pytest.mark.anyio +async def test_auto_driver_gives_up_when_the_chain_outlasts_its_round_budget(): + # A dependency chain of 11 eliciting resolvers needs 11 retry rounds, one more + # than the default `input_required_max_rounds`, so `client.call_tool` must raise + # rather than loop on. The pure `count_leg` resolver is never persisted, so it + # re-runs on every server leg: its final value is the exact number of legs. + mcp = MCPServer(name="TooDeep") + legs = 0 + + async def count_leg(ctx: Context) -> int: + nonlocal legs + legs += 1 + return legs + + async def root(ctx: Context) -> Elicit[Confirm]: + return Elicit("Q1?", Confirm) + + def extend(dep: Callable[..., Any], n: int) -> Callable[..., Any]: + async def link(prev: Annotated[Confirm, Resolve(dep)]) -> Elicit[Confirm]: + return Elicit(f"Q{n}?", Confirm) + + return link + + chain: Callable[..., Any] = root + for n in range(2, 12): # 11 eliciting resolvers in total + chain = extend(chain, n) + + @mcp.tool() + async def long_haul( + leg: Annotated[int, Resolve(count_leg)], + last: Annotated[Confirm, Resolve(chain)], + ) -> str: + raise NotImplementedError # pragma: no cover - the driver gives up first + + answered = 0 + + async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + nonlocal answered + answered += 1 + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=callback) as client: + with anyio.fail_after(5): # the loop must end by raising, not spin on retries + with pytest.raises(InputRequiredRoundsExceededError) as exc_info: + await client.call_tool("long_haul", {}) + assert exc_info.value.max_rounds == client.input_required_max_rounds + assert answered == client.input_required_max_rounds # one question answered per retry round + assert legs == client.input_required_max_rounds + 1 # the initial call plus one leg per retry + + +@pytest.mark.anyio +async def test_aliased_elicitation_model_round_trips_through_request_state(): + # `_encode_state` must dump accepted models by alias: restore re-validates + # against the alias-keyed shape the client answered with (the rendered + # elicitation schema is alias-keyed). A field-name dump would fail validation + # on the round after next, drop the stored answer, and re-ask the user forever. + mcp = MCPServer(name="AliasState") + + async def who(ctx: Context) -> Elicit[Handle]: + return Elicit("handle?", Handle) + + async def confirm(h: Annotated[Handle, Resolve(who)]) -> Elicit[Confirm]: + return Elicit(f"go as {h.user_name}?", Confirm) + + @mcp.tool() + async def act( + h: Annotated[Handle, Resolve(who)], + c: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{h.user_name}:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (who_key,) = first.input_requests + + second = await client.session.call_tool( + "act", + {}, + input_responses={who_key: ElicitResult(action="accept", content={"userName": "octocat"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (confirm_key,) = second.input_requests # only the dependent question; the stored answer holds + assert confirm_key != who_key + + final = await client.session.call_tool( + "act", + {}, + input_responses={confirm_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" From 18217039c6ba99905d0459efabbc767536d9f7ab Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:46:12 +0000 Subject: [PATCH 07/10] Run refund_desk on the dual era The manifest pinned the story to era="legacy" until resolver elicitation could ride the 2026-07-28 input_required round-trip; that is now the case, so the story runs the full transport x era matrix with identical author code. README caveats describe the per-era transport and what persists across rounds; the spec links point at the real anchor (#input-required-tool-results), fixing the same stale fragment in the mrtr story; legacy_elicitation's cross-reference no longer claims resolver DI rides push elicitation everywhere. --- examples/stories/legacy_elicitation/README.md | 4 +- examples/stories/manifest.toml | 5 +-- examples/stories/mrtr/README.md | 2 +- examples/stories/refund_desk/README.md | 40 +++++++++++++------ examples/stories/refund_desk/client.py | 6 ++- 5 files changed, 37 insertions(+), 20 deletions(-) diff --git a/examples/stories/legacy_elicitation/README.md b/examples/stories/legacy_elicitation/README.md index 1a9d48e60..e9812aced 100644 --- a/examples/stories/legacy_elicitation/README.md +++ b/examples/stories/legacy_elicitation/README.md @@ -68,6 +68,6 @@ uv run python -m stories.legacy_elicitation.client --http --legacy --server serv ## See also `sampling/` (same push-request shape, deprecated per SEP-2577), `mrtr/` -(planned — the 2026-era carrier), `error_handling/` +(the 2026-era carrier), `error_handling/` (`UrlElicitationRequiredError`), `refund_desk/` (resolver DI rides this push -mechanism today). +mechanism on handshake-era connections). diff --git a/examples/stories/manifest.toml b/examples/stories/manifest.toml index 57ec0e8a4..1ba2fe862 100644 --- a/examples/stories/manifest.toml +++ b/examples/stories/manifest.toml @@ -40,9 +40,8 @@ era = "legacy" status = "legacy" [story.refund_desk] -# Resolver DI rides push elicitation (ctx.elicit) today; era flips to "dual" once -# the SDK carries resolver elicitation over the 2026 input_required round-trip. -era = "legacy" +# Resolver elicitation picks its transport per era: input_required round-trips on +# the modern leg, push elicitation (ctx.elicit) on the legacy one. lowlevel = false [story.sampling] diff --git a/examples/stories/mrtr/README.md b/examples/stories/mrtr/README.md index de214988d..aaad86ca9 100644 --- a/examples/stories/mrtr/README.md +++ b/examples/stories/mrtr/README.md @@ -46,7 +46,7 @@ uv run python -m stories.mrtr.client --http --server server_lowlevel ## Spec -[Multi-round results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#multi-round-results) +[Input required tool results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#input-required-tool-results) ## See also diff --git a/examples/stories/refund_desk/README.md b/examples/stories/refund_desk/README.md index 0a77dd580..664995e13 100644 --- a/examples/stories/refund_desk/README.md +++ b/examples/stories/refund_desk/README.md @@ -7,9 +7,10 @@ reason)` refunds what the order record says — `cents` is resolver-computed and does not appear in the input schema at all, so the model cannot supply or inflate the amount. Resolvers form a DAG (`load_order` → `refund_scope` → `refund_amount` / `ask_restock`), may return `Elicit[...]` to ask the human, -and run at most once per call. A resolver's own plain parameters are filled -from the tool's arguments by name — `load_order(order_id)` receives the -`order_id` the model passed to `refund_order`. +and ask each question at most once per call. A resolver's own plain +parameters are filled from the tool's arguments by name — +`load_order(order_id)` receives the `order_id` the model passed to +`refund_order`. ## Run it @@ -18,9 +19,9 @@ from the tool's arguments by name — `load_order(order_id)` receives the uv run python -m stories.refund_desk.client # HTTP — the client self-hosts the server on a free port, runs, then tears it -# down (--legacy: resolver elicitation rides the push request today; the -# manifest pins this era, so bare --http runs the same leg) -uv run python -m stories.refund_desk.client --http --legacy +# down (2026 protocol: the questions ride embedded input_required round-trips; +# add --legacy to ride synchronous push elicitation instead) +uv run python -m stories.refund_desk.client --http ``` ## What to look at @@ -47,21 +48,36 @@ uv run python -m stories.refund_desk.client --http --legacy ## Caveats +- **Transport per era.** The framework picks the elicitation transport from + the negotiated protocol: at >= 2026-07-28 the questions ride embedded + `input_required` round-trips (a resolver that depends on another's answer is + asked in a later round); at <= 2025-11-25 each is a synchronous + `elicitation/create` push request mid-call. Author code is identical on + both — this client runs unchanged on either era. - **Decline order.** A declined unwrapped dependency aborts resolution in tool-signature order — `cents` resolves before `restock`, so `ask_restock` never runs. Don't rely on a later resolver's side effects after an earlier consumer can abort. -- **Memoization scope.** Each resolver runs at most once per `tools/call`, - keyed by function identity; nothing is cached across calls or connections. +- **Memoization scope.** Each question is asked at most once per call, and + within a round each resolver runs at most once, keyed by function identity. + Across 2026 rounds only *elicited* outcomes persist (in `requestState`); a + resolver that resolves without eliciting is pure and may re-run each round. + An answer is matched back to its question when the call resumes, so an + eliciting resolver must derive its question deterministically from the + tool's arguments and earlier answers; a per-call generated value (a + `default_factory` id, a timestamp) is re-derived each round and must not + appear in a question the answer is meant to bind to. Nothing is cached + across calls or connections. - **Validate elicited values.** Elicited answers are human-typed; check them against your records (as `_scoped` does) before acting on them. ## Spec -[Elicitation — client features](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation) +[Elicitation — client features](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation), +[Input required tool results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#input-required-tool-results) ## See also -`legacy_elicitation/` (the push mechanism resolver elicitation rides on today), -`mrtr/` (the 2026 `input_required` carrier; resolver DI will ride it once the -SDK wires them together). +`mrtr/` (the 2026 `input_required` carrier these questions ride at +>= 2026-07-28), `legacy_elicitation/` (the push mechanism they ride on +handshake-era connections). diff --git a/examples/stories/refund_desk/client.py b/examples/stories/refund_desk/client.py index ee86d94b4..0ff8d28fc 100644 --- a/examples/stories/refund_desk/client.py +++ b/examples/stories/refund_desk/client.py @@ -41,7 +41,9 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa assert counts == {"scope": 0, "restock": 0}, counts # Full refund of a three-line order. The scope question fires exactly ONCE even though - # both refund_amount and ask_restock consume it — memoized within the call. + # both refund_amount and ask_restock consume it — asked at most once per call on either + # era. ask_restock needs the scope ANSWER, so at 2026 the two questions land in + # successive rounds, never one concurrent batch: counts and order are era-independent. receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "arrived broken"}) assert receipt.structured_content == { "order_id": "ORD-7002", @@ -53,7 +55,7 @@ async def on_elicit(context: ClientRequestContext, params: types.ElicitRequestPa # Declining restock still refunds: the tool keeps the ElicitationResult union for # `restock`, sees the decline, and just skips the restock. The scope counter moves - # again — the memo cache is per tools/call, not per connection. + # again — questions are deduped per call, not per connection. declines.add("restock") answers["scope"] = {"full": False, "sku": "canvas-tote"} receipt = await client.call_tool("refund_order", {"order_id": "ORD-7002", "reason": "wrong colour"}) From 85756a79e01259f33b42d43f0b1abcfdddb2944a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:46:25 +0000 Subject: [PATCH 08/10] Reconcile docs with the negotiated elicitation transport The Dependencies tutorial explains that the framework picks the question's transport from the negotiated protocol version, scopes the run-once claim to what is guaranteed (each question asked once; a non-eliciting resolver may re-run when a call resumes), and documents the determinism constraint on questions. Multi-round-trip requests no longer claims @mcp.tool() has no high-level path to InputRequiredResult - resolver dependencies are that path. The eliciting tutorial snippets are tested on both transports. --- docs/advanced/multi-round-trip.md | 4 ++-- docs/tutorial/dependencies.md | 15 ++++++++++++++- tests/docs_src/test_dependencies.py | 23 +++++++++++++++++------ 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/docs/advanced/multi-round-trip.md b/docs/advanced/multi-round-trip.md index de11a8db8..665808a5d 100644 --- a/docs/advanced/multi-round-trip.md +++ b/docs/advanced/multi-round-trip.md @@ -19,7 +19,7 @@ That's the whole protocol. Every leg is an ordinary request from the client to t ## The server side -The high-level `@mcp.tool()` decorator has no sugar for this yet. Today you write it on the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: +On `@mcp.tool()` you rarely build this by hand: declare a dependency that asks the user and the SDK returns the `InputRequiredResult` for you - that form is the **[Dependencies](../tutorial/dependencies.md)** tutorial. The manual form is the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: ```python title="server.py" hl_lines="44-47" --8<-- "docs_src/mrtr/tutorial001.py" @@ -93,6 +93,6 @@ Drop to the underlying session, where `allow_input_required=True` hands you the * `input_requests` is what it needs. `request_state` is an opaque resume token only the server reads. * `Client` runs the retry loop for you: register `elicitation_callback` / `sampling_callback` / `list_roots_callback` and `call_tool` returns a plain `CallToolResult`. `input_required_max_rounds` (default 10) bounds it. * To inspect or persist rounds, use `client.session.call_tool(..., allow_input_required=True)` and own the `while isinstance(result, InputRequiredResult)` loop yourself. -* The server side is the **low-level** `Server` only; `@mcp.tool()` has no sugar for this yet. +* On `@mcp.tool()`, a dependency that asks the user produces this result for you (**[Dependencies](../tutorial/dependencies.md)**); the **low-level** `Server` is the manual form. This is the mechanism that replaces server-initiated sampling and the rest of the push-style back-channel; see **[Deprecated features](deprecated.md)**. diff --git a/docs/tutorial/dependencies.md b/docs/tutorial/dependencies.md index e9b4c789b..47b75f0b6 100644 --- a/docs/tutorial/dependencies.md +++ b/docs/tutorial/dependencies.md @@ -116,11 +116,24 @@ And if the user won't answer at all - declines the question, or cancels it? That's the right default for a precondition: no answer, no order. When declining is an outcome your tool wants to handle - skip the backorder but still suggest another title - annotate `ElicitationResult[Backorder]` instead and the tool receives the full accept/decline/cancel outcome to branch on. **[Elicitation](elicitation.md)** shows that form, and everything else about asking: the schema rules, the three answers, the client's side of the conversation. +!!! info + The framework picks the question's transport from the negotiated protocol version; the code + above is identical on both. On **2026-07-28** and later the question rides inside a + multi-round-trip `tools/call` - the server returns it, the client's `elicitation_callback` + answers it, and the `Client` retries the call for you (**[Multi-round-trip requests](../advanced/multi-round-trip.md)**). On + **2025-11-25** and earlier it is a synchronous elicitation request mid-call. Each question is + asked exactly once per call; a resolver that answered *without* asking, like `check_stock`, + may run again when the call resumes after a question. When it resumes, each answer is matched + back to its question, so an eliciting resolver must derive its question deterministically from + the tool's arguments and earlier answers - a per-call generated value (a `default_factory` id, + a timestamp) is re-derived on each round and must not appear in a question the answer is meant + to bind to. + ## Recap * `Annotated[T, Resolve(fn)]` on a tool parameter: the SDK runs `fn` and injects its return value. * A resolved parameter is invisible to the model and cannot be supplied by a client. Values the model must not invent - prices, identities, permissions - belong here. -* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per call. +* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once, however many consumers it has; a resolver that never asked may run again when a call resumes after a question. * Bad graphs fail at registration with `InvalidSignature`, not mid-call. * Return `Elicit(message, Model)` to ask the user, only when you have to. Unwrapped annotations abort on decline; `ElicitationResult[T]` lets the tool branch. diff --git a/tests/docs_src/test_dependencies.py b/tests/docs_src/test_dependencies.py index 73355a892..06d893585 100644 --- a/tests/docs_src/test_dependencies.py +++ b/tests/docs_src/test_dependencies.py @@ -1,5 +1,7 @@ """`docs/tutorial/dependencies.md`: every claim the page makes, proved against the real SDK.""" +from typing import Literal + import pytest from inline_snapshot import snapshot from mcp_types import ElicitRequestParams, ElicitResult, TextContent @@ -79,18 +81,24 @@ def get(self, key: str, default: int) -> int: assert inventory.lookups == ["Dune", "Dune"] -async def test_an_in_stock_order_asks_no_question() -> None: +# The `!!! info` claims the tutorial003 behaviour is transport-independent, so each claim is +# proved on both: mode="legacy" elicits synchronously mid-call (2025-11-25 and earlier), while +# mode="auto" negotiates 2026-07-28, where the question rides a multi-round-trip `tools/call` +# and `Client` drives the retries. +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_an_in_stock_order_asks_no_question(mode: Literal["legacy", "auto"]) -> None: """tutorial003: `confirm_backorder` returns directly when stock exists - no round-trip.""" async def never(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: # pragma: no cover raise AssertionError("an in-stock order must not elicit") - async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=never) as client: + async with Client(tutorial003.mcp, mode=mode, elicitation_callback=never) as client: result = await client.call_tool("order_book", {"title": "Dune"}) assert result.content == [TextContent(type="text", text="Ordered 'Dune'.")] +@pytest.mark.parametrize("mode", ["legacy", "auto"]) @pytest.mark.parametrize( ("confirm", "expected"), [ @@ -98,7 +106,9 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E (False, "No order placed."), ], ) -async def test_an_out_of_stock_order_asks_and_honours_the_answer(confirm: bool, expected: str) -> None: +async def test_an_out_of_stock_order_asks_and_honours_the_answer( + mode: Literal["legacy", "auto"], confirm: bool, expected: str +) -> None: """tutorial003: the resolver elicits, the SDK validates the answer, the tool reads it.""" asked: list[str] = [] @@ -106,20 +116,21 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) asked.append(params.message) return ElicitResult(action="accept", content={"confirm": confirm}) - async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=on_elicit) as client: + async with Client(tutorial003.mcp, mode=mode, elicitation_callback=on_elicit) as client: result = await client.call_tool("order_book", {"title": "Neuromancer"}) assert result.content == [TextContent(type="text", text=expected)] assert asked == ["'Neuromancer' is out of stock (2-3 weeks). Order anyway?"] -async def test_declining_an_unwrapped_dependency_aborts_the_call() -> None: +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_declining_an_unwrapped_dependency_aborts_the_call(mode: Literal["legacy", "auto"]) -> None: """tutorial003: no answer, no order - the error text on the page is the real one.""" async def decline(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") - async with Client(tutorial003.mcp, mode="legacy", elicitation_callback=decline) as client: + async with Client(tutorial003.mcp, mode=mode, elicitation_callback=decline) as client: result = await client.call_tool("order_book", {"title": "Neuromancer"}) assert result.is_error From ef53c4d320e7eedaf637dc7ed6c4dc2998bb8422 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 29 Jun 2026 13:24:51 +0000 Subject: [PATCH 09/10] Persist elicited answers as raw wire content; reject ambiguous Elicit returns request_state entries now store exactly the content the client sent (which already passed validation) instead of re-deriving it from the validated model with model_dump - persistence and restore previously used two codecs that diverge whenever validation and serialization aliases differ, dropping a legitimate stored answer on the third round and re-asking forever. Encode and restore are now the identity on the client's bytes, so no serialization knob can break the round-trip; the earlier by-alias dump is gone with the rest of the model-dumping path. A return annotation with more than one distinct Elicit arm now raises InvalidSignature at registration: the restore path validates against the single statically-known schema, so `-> Elicit[A] | Elicit[B]` either never converges (the stored B answer fails A validation each round) or silently injects a wrong-typed model when the shapes happen to overlap. Like cycles and unclassifiable parameters, the ambiguous shape is rejected up front. Legacy parity: elicit_with_validation wraps a schema-mismatched accepted answer in the same stable error the 2026 path uses, instead of letting the raw pydantic ValidationError text reach the client; noted in the migration guide since callers catching ValidationError see ValueError now. --- docs/migration.md | 2 + docs/tutorial/elicitation.md | 4 +- src/mcp/server/elicitation.py | 11 ++- src/mcp/server/mcpserver/resolve.py | 74 ++++++++------ tests/docs_src/test_elicitation.py | 2 +- tests/server/mcpserver/test_resolve.py | 130 ++++++++++++++++++++----- 6 files changed, 163 insertions(+), 60 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 79c15d91f..fd76d8a4f 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -786,6 +786,8 @@ Positional calls (`await ctx.info("hello")`) are unaffected. `Context.elicit()` (and `elicit_with_validation()`) now render the schema first and validate each property against the spec's `PrimitiveSchemaDefinition`, raising `TypeError` at the call site for anything outside it. `Optional[T]` fields render as `{"type": ...}` with the field omitted from `required` (previously the non-spec `anyOf` shape). A bare `list[str]` field is rejected because it renders without the required enum items; use `list[Literal[...]]` or `list[str]` with `json_schema_extra` supplying the items. Unions of multiple primitives (e.g. `int | str`) and nested models are rejected. +A schema-mismatched *accepted* answer also fails differently: the call now raises `ValueError` with a stable message ("Received an accepted elicitation whose content does not match the requested schema") instead of letting pydantic's `ValidationError` escape with its internals. Code that caught `ValidationError` around `ctx.elicit()` should catch `ValueError` (or rely on the tool's error result). + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: diff --git a/docs/tutorial/elicitation.md b/docs/tutorial/elicitation.md index aa4f16820..7bd27a78a 100644 --- a/docs/tutorial/elicitation.md +++ b/docs/tutorial/elicitation.md @@ -76,8 +76,8 @@ A refusal is not an error. The tool decides what declining means (here, no booki !!! tip The answer is validated against your model before your code sees it. A client that sends - `"maybe"` for a `bool` doesn't corrupt your booking: the call fails with the - `ValidationError`, your `if` never runs. + `"maybe"` for a `bool` doesn't corrupt your booking: the call fails with a + schema-mismatch error, your `if` never runs. ## Ask before the tool runs diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index e730c7bfb..5a4acdd6c 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -116,7 +116,8 @@ async def elicit_with_validation( For sensitive data like credentials or OAuth flows, use elicit_url() instead. Raises: - ValueError: If the client accepted the elicitation without supplying content. + ValueError: If the client accepted the elicitation without supplying + content, or with content that does not match the requested schema. """ json_schema = render_elicitation_schema(schema) @@ -129,8 +130,12 @@ async def elicit_with_validation( if result.action == "accept": if result.content is None: raise ValueError("Received an accepted elicitation with no content") - # Validate and parse the content using the schema - validated_data = schema.model_validate(result.content) + try: + validated_data = schema.model_validate(result.content) + except ValidationError as e: + raise ValueError( + "Received an accepted elicitation whose content does not match the requested schema" + ) from e return AcceptedElicitation(data=validated_data) if result.action == "decline": return DeclinedElicitation() diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index e033c5c3c..323ce5cdd 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -183,21 +183,30 @@ def _contains_resolve(annotation: Any) -> bool: return any(_contains_resolve(arg) for arg in get_args(annotation)) -def _elicit_return_schema(return_annotation: Any) -> type[BaseModel] | None: +def _elicit_return_schema(return_annotation: Any, name: str) -> type[BaseModel] | None: """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. Handles a bare `-> Elicit[T]` and a `-> T | Elicit[T]` union. Lets an elicited outcome restored from `request_state` (a plain dict) be re-validated into its model so dependent resolvers and tools receive a typed value. + + Raises: + InvalidSignature: If the annotation has more than one `Elicit[...]` arm; + the runtime can honor only one static question schema per resolver. """ # A bare `Elicit[T]` is itself a candidate; a union contributes its members. candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) - for candidate in candidates: - if get_origin(candidate) is Elicit: - schema = get_args(candidate)[0] - if isinstance(schema, type) and issubclass(schema, BaseModel): - return schema - return None + # Typing dedupes equal union members, so two arms here are genuinely distinct. + arms = [c for c in candidates if get_origin(c) is Elicit] + if len(arms) > 1: + raise InvalidSignature( + f"Resolver {name!r} return annotation has multiple Elicit arms; " + "a resolver asks one question - split it into separate resolvers" + ) + if not arms: + return None + schema = get_args(arms[0])[0] + return schema if isinstance(schema, type) and issubclass(schema, BaseModel) else None def _is_union(annotation: Any) -> bool: @@ -290,9 +299,8 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - plans[key] = _ResolverPlan( - fn, params, is_async_callable(fn), _elicit_return_schema(hints.get("return")), wire_key - ) + elicit_schema = _elicit_return_schema(hints.get("return"), _resolver_name(fn)) + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), elicit_schema, wire_key) for dep in nested: analyze(dep, stack + (key,)) @@ -342,11 +350,13 @@ def __init__( self.answers: InputResponses = context.input_responses or {} if input_required else {} self.state = _decode_state(context.request_state) if input_required else {} # In-call dedup keyed by resolver identity (distinguishes two instances of - # the same bound method); `elicited` holds only outcomes that came from an - # elicitation, keyed by their wire key - these are what `request_state` - # persists, since pure resolvers are cheap to re-run each round. + # the same bound method); `persist` holds the wire-shaped record of each + # elicited outcome, keyed by its wire key - exactly what the next round's + # `request_state` carries. Entries are the client's own (validated) wire + # data, never re-derived from a model, so encode-restore is the identity. + # Pure resolvers are cheap to re-run each round and are not persisted. self.cache: dict[Hashable, ElicitationResult[Any]] = {} - self.elicited: dict[str, ElicitationResult[Any]] = {} + self.persist: dict[str, _StateEntry] = {} self.pending: InputRequests = {} @@ -399,7 +409,7 @@ async def resolve_arguments( injected[name] = outcome if wants_union else _unwrap(outcome, name) if res.pending: - return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.elicited)) + return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.persist)) return injected @@ -456,7 +466,6 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul if _is_elicit(result): outcome = await _elicit(result, wire_key, res) - res.elicited[wire_key] = outcome else: # A resolver may return any type (not just `BaseModel`), so accept it as the # outcome without validating against the schema bound. Plain outcomes are not @@ -496,9 +505,14 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio raise ToolError( f"Resolver {key!r} received an accepted elicitation whose content does not match the requested schema" ) from e + # Persist the exact wire content that just passed validation - never the + # model - so restoring next round revalidates the same bytes the client sent. + res.persist[key] = _StateEntry(action="accept", data=answer.content) return AcceptedElicitation(data=data) if answer.action == "decline": + res.persist[key] = _StateEntry(action="decline") return DeclinedElicitation() + res.persist[key] = _StateEntry(action="cancel") return CancelledElicitation() @@ -591,18 +605,13 @@ def _decode_state(request_state: str | None) -> dict[str, _StateEntry]: return state.outcomes if state.v == _STATE_VERSION else {} -def _encode_state(outcomes: Mapping[str, ElicitationResult[Any]]) -> str: - """Encode resolved outcomes (keyed by resolver path) for the next round.""" - entries: dict[str, _StateEntry] = {} - for path, outcome in outcomes.items(): - data = outcome.data if isinstance(outcome, AcceptedElicitation) else None - if isinstance(data, BaseModel): - # By alias: the stored shape must round-trip through - # `schema.model_validate` on restore, which expects the alias-keyed - # form the client answered with (the rendered schema is alias-keyed). - data = data.model_dump(mode="json", by_alias=True) - entries[path] = _StateEntry(action=outcome.action, data=data) - return _State(v=_STATE_VERSION, outcomes=entries).model_dump_json() +def _encode_state(outcomes: Mapping[str, _StateEntry]) -> str: + """Encode recorded elicitation outcomes (keyed by wire key) for the next round. + + Entries already hold the client's wire-shaped data exactly as it was sent (and + validated), so encoding is pure wrapping: encode-restore is the identity. + """ + return _State(v=_STATE_VERSION, outcomes=dict(outcomes)).model_dump_json() def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> ElicitationResult[Any]: @@ -629,9 +638,10 @@ def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) the `_decode_state` treatment - dropped as if no progress was recorded, so the question is asked again - rather than surfacing a validation error. - Carries a restored outcome forward in `res.elicited`: if a later resolver is - still pending, the next round's `request_state` is built from `res.elicited`, - so an earlier answer must stay there or it would be dropped and re-asked. + Carries the original decoded entry forward unchanged in `res.persist`: if a + later resolver is still pending, the next round's `request_state` is built from + `res.persist`, so an earlier answer must stay there - byte-identical, never + re-derived - or it would be dropped and re-asked. """ entry = res.state.get(key) if entry is None: @@ -641,7 +651,7 @@ def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) except ValidationError: del res.state[key] return None - res.elicited[key] = outcome + res.persist[key] = entry return outcome diff --git a/tests/docs_src/test_elicitation.py b/tests/docs_src/test_elicitation.py index 4c9bb4036..a28f1087f 100644 --- a/tests/docs_src/test_elicitation.py +++ b/tests/docs_src/test_elicitation.py @@ -124,7 +124,7 @@ async def on_elicit(context: ClientRequestContext, params: ElicitRequestParams) result = await client.call_tool("book_table", {"date": "2025-12-25", "party_size": 2}) assert result.is_error assert isinstance(result.content[0], TextContent) - assert "Input should be a valid boolean" in result.content[0].text + assert "does not match the requested schema" in result.content[0].text class Address(BaseModel): diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 4bef8463f..7e92f1c4e 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -40,6 +40,7 @@ _outcome_from_state, _resolver_key, _state_key, + _StateEntry, _uses_input_required, find_resolved_parameters, ) @@ -59,6 +60,10 @@ class Handle(BaseModel): user_name: str = Field(alias="userName") +class Account(BaseModel): + user_name: str = Field(validation_alias="vUser", serialization_alias="sUser") + + async def _alias_login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover - only the signature is inspected @@ -337,6 +342,20 @@ async def tool(login: Annotated[Login, Resolve(login)]) -> str: Tool.from_function(tool) +def test_multiple_elicit_arms_raise_at_registration(): + # The runtime can honor only one static question schema per resolver, so an + # ambiguous `-> Elicit[A] | Elicit[B]` must not register (the second arm used + # to be silently ignored). + async def ambiguous(ctx: Context) -> Elicit[Login] | Elicit[Confirm]: + raise NotImplementedError # pragma: no cover + + async def tool(login: Annotated[Login, Resolve(ambiguous)]) -> str: + return login.username # pragma: no cover + + with pytest.raises(InvalidSignature, match="multiple Elicit arms"): + Tool.from_function(tool) + + def test_resolve_marker_inside_a_union_raises_at_registration(): async def login(ctx: Context) -> Login: return Login(username="x") # pragma: no cover @@ -833,13 +852,14 @@ def test_decode_state_tolerates_malformed_request_state(request_state: str | Non def test_state_round_trips_accept_decline_cancel(): - outcomes: dict[str, ElicitationResult[BaseModel]] = { - "a": AcceptedElicitation(data=Login(username="octocat")), - "b": DeclinedElicitation(), - "c": CancelledElicitation(), - "d": AcceptedElicitation.model_construct(data="raw-token"), # non-model value + entries = { + "a": _StateEntry(action="accept", data={"username": "octocat"}), + "b": _StateEntry(action="decline"), + "c": _StateEntry(action="cancel"), + "d": _StateEntry(action="accept", data="raw-token"), # non-dict wire value } - decoded = _decode_state(_encode_state(outcomes)) + decoded = _decode_state(_encode_state(entries)) + assert decoded == entries # encode-restore is the identity on the stored entries accepted = _outcome_from_state(decoded["a"], Login) assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") @@ -850,14 +870,17 @@ def test_state_round_trips_accept_decline_cancel(): def test_elicit_return_schema_extraction(): - assert _elicit_return_schema(Elicit[Login]) is Login # bare Elicit[T] - assert _elicit_return_schema(Login | Elicit[Login]) is Login # union arm - assert _elicit_return_schema(Login) is None # no Elicit arm - assert _elicit_return_schema(None) is None + assert _elicit_return_schema(Elicit[Login], "r") is Login # bare Elicit[T] + assert _elicit_return_schema(Login | Elicit[Login], "r") is Login # union arm + assert _elicit_return_schema(Login, "r") is None # no Elicit arm + assert _elicit_return_schema(None, "r") is None # The bound on `Elicit`'s parameter is unenforced at runtime, so a non-model # subscription is constructible and must yield no schema rather than crash. unbounded_elicit: Any = Elicit - assert _elicit_return_schema(unbounded_elicit[int]) is None + assert _elicit_return_schema(unbounded_elicit[int], "r") is None + # Two distinct Elicit arms are ambiguous: the runtime can honor only one schema. + with pytest.raises(InvalidSignature, match="'r' return annotation has multiple Elicit arms"): + _elicit_return_schema(Elicit[Login] | Elicit[Confirm], "r") @pytest.mark.anyio @@ -1385,10 +1408,11 @@ async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: @pytest.mark.anyio -async def test_schema_mismatched_fresh_answer_fails_the_call_without_pydantic_leakage(): - # An accepted answer whose content fails the requested schema fails the call with - # the resolver's own message; pydantic's error text (which carries a - # "For further information" link) must not leak to the client. +@pytest.mark.parametrize("mode", ["legacy", "auto"]) +async def test_schema_mismatched_fresh_answer_fails_the_call_without_pydantic_leakage(mode: Literal["legacy", "auto"]): + # An accepted answer whose content fails the requested schema fails the call + # with the framework's own message on both transports; pydantic's error text + # (which carries an "errors.pydantic.dev" link) must not leak to the client. mcp = MCPServer(name="MismatchedAnswer") async def ask(ctx: Context) -> Elicit[Login]: @@ -1398,12 +1422,17 @@ async def ask(ctx: Context) -> Elicit[Login]: async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: raise NotImplementedError # pragma: no cover - the mismatched answer never reaches the body - async with Client(mcp, elicitation_callback=_accept({"nope": "x"})) as client: + async with Client(mcp, mode=mode, elicitation_callback=_accept({"nope": "x"})) as client: result = await client.call_tool("whoami", {}) assert result.is_error assert isinstance(result.content[0], TextContent) - assert "does not match the requested schema" in result.content[0].text - assert "For further information" not in result.content[0].text + text = result.content[0].text + assert "does not match the requested schema" in text + assert "errors.pydantic.dev" not in text + if mode == "auto": + assert "Resolver" in text # the input_required transport names the offending resolver key + else: + assert "Received an accepted elicitation" in text # the legacy path has no wire key to name @pytest.mark.anyio @@ -1458,10 +1487,10 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - @pytest.mark.anyio async def test_aliased_elicitation_model_round_trips_through_request_state(): - # `_encode_state` must dump accepted models by alias: restore re-validates - # against the alias-keyed shape the client answered with (the rendered - # elicitation schema is alias-keyed). A field-name dump would fail validation - # on the round after next, drop the stored answer, and re-ask the user forever. + # The stored entry is the client's raw wire content, so it restores through + # the same validation the answer originally passed - aliases and all. A + # re-derived (field-name) shape would fail validation on the round after + # next, drop the stored answer, and re-ask the user forever. mcp = MCPServer(name="AliasState") async def who(ctx: Context) -> Elicit[Handle]: @@ -1505,3 +1534,60 @@ async def act( assert isinstance(final, CallToolResult) assert isinstance(final.content[0], TextContent) assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_divergent_validation_and_serialization_aliases_round_trip(): + # `request_state` must carry the client's answer exactly as it was sent: the + # rendered question is validation-aliased, so re-deriving the stored shape from + # the validated model (which serializes under the *serialization* alias) would + # produce data the schema's own validation rejects, dropping the stored answer + # on the round after next and re-asking the user. + mcp = MCPServer(name="DivergentAliases") + + async def who(ctx: Context) -> Elicit[Account]: + return Elicit("account?", Account) + + async def confirm(a: Annotated[Account, Resolve(who)]) -> Elicit[Confirm]: + return Elicit(f"go as {a.user_name}?", Confirm) + + @mcp.tool() + async def act( + a: Annotated[Account, Resolve(who)], + c: Annotated[Confirm, Resolve(confirm)], + ) -> str: + return f"{a.user_name}:{c.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (who_key,) = first.input_requests + question = first.input_requests[who_key].params + assert isinstance(question, ElicitRequestFormParams) + assert "vUser" in question.requested_schema["properties"] # the client answers validation-aliased + + second = await client.session.call_tool( + "act", + {}, + input_responses={who_key: ElicitResult(action="accept", content={"vUser": "octocat"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + (go_key,) = second.input_requests # only the dependent question; the stored answer holds + assert go_key != who_key + # The stored entry is the client's wire content, not a re-serialization of it. + assert _decode_state(second.request_state)[who_key].data == {"vUser": "octocat"} + + final = await client.session.call_tool( + "act", + {}, + input_responses={go_key: ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" From 1a15cb6eb009af441da88e64d9d0544b950432d2 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 29 Jun 2026 13:25:05 +0000 Subject: [PATCH 10/10] Scope the resolver run-once guarantee to questions, not resolver bodies Each question is asked exactly once per call, but on the multi-round-trip form an eliciting resolver's body runs again to consume its answer, and a resolver that answered without asking may run again whenever the call resumes. The tutorial info box, its recap bullet, and the refund_desk caveats now say exactly that, so authors don't hang side effects on a runs-at-most-once reading that only holds for the synchronous form. --- docs/tutorial/dependencies.md | 16 +++++++++------- examples/stories/refund_desk/README.md | 2 ++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/tutorial/dependencies.md b/docs/tutorial/dependencies.md index 47b75f0b6..4e91515ef 100644 --- a/docs/tutorial/dependencies.md +++ b/docs/tutorial/dependencies.md @@ -122,18 +122,20 @@ That's the right default for a precondition: no answer, no order. When declining multi-round-trip `tools/call` - the server returns it, the client's `elicitation_callback` answers it, and the `Client` retries the call for you (**[Multi-round-trip requests](../advanced/multi-round-trip.md)**). On **2025-11-25** and earlier it is a synchronous elicitation request mid-call. Each question is - asked exactly once per call; a resolver that answered *without* asking, like `check_stock`, - may run again when the call resumes after a question. When it resumes, each answer is matched - back to its question, so an eliciting resolver must derive its question deterministically from - the tool's arguments and earlier answers - a per-call generated value (a `default_factory` id, - a timestamp) is re-derived on each round and must not appear in a question the answer is meant - to bind to. + asked exactly once per call - a guarantee about the question, not the resolver. In the + multi-round-trip form an eliciting resolver runs again to consume its answer, so code before + its `return Elicit(...)` runs on the asking round and again on the answering one; a resolver + that answered *without* asking, like `check_stock`, may run again whenever the call resumes + after a question. When it resumes, each answer is matched back to its question, so an + eliciting resolver must derive its question deterministically from the tool's arguments and + earlier answers - a per-call generated value (a `default_factory` id, a timestamp) is + re-derived on each round and must not appear in a question the answer is meant to bind to. ## Recap * `Annotated[T, Resolve(fn)]` on a tool parameter: the SDK runs `fn` and injects its return value. * A resolved parameter is invisible to the model and cannot be supplied by a client. Values the model must not invent - prices, identities, permissions - belong here. -* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once, however many consumers it has; a resolver that never asked may run again when a call resumes after a question. +* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per round, however many consumers it has; each question is asked exactly once, an eliciting resolver runs again to consume its answer, and a resolver that never asked may run again when a call resumes. * Bad graphs fail at registration with `InvalidSignature`, not mid-call. * Return `Elicit(message, Model)` to ask the user, only when you have to. Unwrapped annotations abort on decline; `ElicitationResult[T]` lets the tool branch. diff --git a/examples/stories/refund_desk/README.md b/examples/stories/refund_desk/README.md index 664995e13..153504041 100644 --- a/examples/stories/refund_desk/README.md +++ b/examples/stories/refund_desk/README.md @@ -62,6 +62,8 @@ uv run python -m stories.refund_desk.client --http within a round each resolver runs at most once, keyed by function identity. Across 2026 rounds only *elicited* outcomes persist (in `requestState`); a resolver that resolves without eliciting is pure and may re-run each round. + An eliciting resolver's body runs again too — once to ask, once more to + consume its answer. An answer is matched back to its question when the call resumes, so an eliciting resolver must derive its question deterministically from the tool's arguments and earlier answers; a per-call generated value (a