diff --git a/docs/advanced/multi-round-trip.md b/docs/advanced/multi-round-trip.md index 665808a5d..883a594e2 100644 --- a/docs/advanced/multi-round-trip.md +++ b/docs/advanced/multi-round-trip.md @@ -19,7 +19,7 @@ That's the whole protocol. Every leg is an ordinary request from the client to t ## The server side -On `@mcp.tool()` you rarely build this by hand: declare a dependency that asks the user and the SDK returns the `InputRequiredResult` for you - that form is the **[Dependencies](../tutorial/dependencies.md)** tutorial. The manual form is the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: +On `@mcp.tool()` you rarely build this by hand: declare a dependency that asks the user and the SDK returns the `InputRequiredResult` for you - that form is the **[Dependencies](../tutorial/dependencies.md)** tutorial. The two forms don't mix: a call has one `input_responses`/`request_state` channel, so a tool that uses `Resolve(...)` parameters cannot also return `InputRequiredResult` from its body. A declared `InputRequiredResult` return is rejected at registration (`InvalidSignature`), and an undeclared one fails the call at runtime. The manual form is the **low-level** `Server`, whose `on_call_tool` handler is allowed to return either result type: ```python title="server.py" hl_lines="44-47" --8<-- "docs_src/mrtr/tutorial001.py" diff --git a/docs/tutorial/dependencies.md b/docs/tutorial/dependencies.md index 4e91515ef..b7b18fe76 100644 --- a/docs/tutorial/dependencies.md +++ b/docs/tutorial/dependencies.md @@ -123,19 +123,21 @@ That's the right default for a precondition: no answer, no order. When declining answers it, and the `Client` retries the call for you (**[Multi-round-trip requests](../advanced/multi-round-trip.md)**). On **2025-11-25** and earlier it is a synchronous elicitation request mid-call. Each question is asked exactly once per call - a guarantee about the question, not the resolver. In the - multi-round-trip form an eliciting resolver runs again to consume its answer, so code before - its `return Elicit(...)` runs on the asking round and again on the answering one; a resolver - that answered *without* asking, like `check_stock`, may run again whenever the call resumes - after a question. When it resumes, each answer is matched back to its question, so an - eliciting resolver must derive its question deterministically from the tool's arguments and - earlier answers - a per-call generated value (a `default_factory` id, a timestamp) is - re-derived on each round and must not appear in a question the answer is meant to bind to. + multi-round-trip form any resolver may run again whenever the call resumes after a question, + so code before a `return Elicit(...)` runs on each of those rounds; the recorded answer then + satisfies the repeated question without prompting the user again. A recorded answer is only + ever consulted when the resolver asks; a resolver that answers *without* asking, like + `check_stock`, always supplies its own computed value. Because each answer is matched back to + its question, an eliciting resolver must derive its question deterministically from the + tool's arguments and earlier answers. A per-call generated value (a `default_factory` id, a + timestamp) is re-derived on each round and must not appear in a question the answer is meant + to bind to. ## Recap * `Annotated[T, Resolve(fn)]` on a tool parameter: the SDK runs `fn` and injects its return value. * A resolved parameter is invisible to the model and cannot be supplied by a client. Values the model must not invent - prices, identities, permissions - belong here. -* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per round, however many consumers it has; each question is asked exactly once, an eliciting resolver runs again to consume its answer, and a resolver that never asked may run again when a call resumes. +* A resolver's parameters are resolved the same way: the `Context`, another `Resolve(...)`, or a tool argument by name. The graph runs each resolver at most once per round, however many consumers it has; each question is asked exactly once, and any resolver may run again when a call resumes after a question. * Bad graphs fail at registration with `InvalidSignature`, not mid-call. * Return `Elicit(message, Model)` to ask the user, only when you have to. Unwrapped annotations abort on decline; `ElicitationResult[T]` lets the tool branch. diff --git a/examples/stories/refund_desk/README.md b/examples/stories/refund_desk/README.md index 153504041..5b5bb5532 100644 --- a/examples/stories/refund_desk/README.md +++ b/examples/stories/refund_desk/README.md @@ -60,10 +60,11 @@ uv run python -m stories.refund_desk.client --http consumer can abort. - **Memoization scope.** Each question is asked at most once per call, and within a round each resolver runs at most once, keyed by function identity. - Across 2026 rounds only *elicited* outcomes persist (in `requestState`); a - resolver that resolves without eliciting is pure and may re-run each round. - An eliciting resolver's body runs again too — once to ask, once more to - consume its answer. + Across 2026 rounds only *elicited* outcomes persist (in `requestState`); any + resolver's body may run again on each round the call passes through. A + recorded answer is consulted only when the resolver asks its question again: + it satisfies the question without re-prompting the user, and it never stands + in for a value the resolver computes itself. An answer is matched back to its question when the call resumes, so an eliciting resolver must derive its question deterministically from the tool's arguments and earlier answers; a per-call generated value (a diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py index 323ce5cdd..9ff8dfeed 100644 --- a/src/mcp/server/mcpserver/resolve.py +++ b/src/mcp/server/mcpserver/resolve.py @@ -13,8 +13,10 @@ (independent resolvers are asked in one round; a resolver depending on another's answer is asked in a later round). At <= 2025-11-25 it issues a synchronous `elicitation/create` request mid-call. Only *elicited* outcomes are carried in -`request_state` across rounds (so the user is asked each question once); a -resolver that returns a value without eliciting is pure and may re-run each round. +`request_state` across rounds (so the user is asked each question once). Resolver +bodies may re-run on every round; a recorded outcome is consulted only when the +body asks its question again, so a resolver's own computation always wins over +anything the client echoes back in `request_state`. Whether the consumer receives the unwrapped model or the full `ElicitationResult` union is decided by the consumer's annotation: @@ -114,15 +116,11 @@ def __init__( fn: Callable[..., Any], params: dict[str, _ParamPlan], is_async: bool, - elicit_schema: type[BaseModel] | None, wire_key: str, ) -> None: self.fn = fn self.params = params self.is_async = is_async - # The `T` from the resolver's `Elicit[T]` return arm, if annotated. Used to - # re-validate an outcome restored from `request_state` into a model. - self.elicit_schema = elicit_schema # Deterministic, collision-free key for this resolver's elicitation on the # wire (`input_requests`/`request_state`). Assigned at registration so it is # stable across rounds even when `module:qualname` collides (closures). @@ -176,6 +174,37 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve, return resolved +def returns_input_required(fn: Callable[..., Any]) -> bool: + """True when `fn`'s return annotation carries an `InputRequiredResult` arm. + + Used at tool registration to reject combining `Resolve(...)` parameters with a + hand-rolled `InputRequiredResult` flow: a call has a single + `input_responses`/`request_state` channel, so the two flows would overwrite + each other's state and the call could never converge. + """ + return _has_input_required_arm(_type_hints(fn).get("return")) + + +def _has_input_required_arm(annotation: Any) -> bool: + """Walk an annotation's arms through `Annotated`, type aliases, and unions.""" + if get_origin(annotation) is Annotated: + return _has_input_required_arm(get_args(annotation)[0]) + # A `type X = ...` / `TypeAliasType` alias carries its target on `__value__` (a + # subscripted alias forwards the attribute to its origin). The access evaluates + # a PEP 695 alias lazily, so an alias naming things unavailable at runtime + # (TYPE_CHECKING-only imports) raises NameError; such an alias declares no arm + # this check can see, and the in-call guard in `Tool.run` still covers it. + try: + value = getattr(annotation, "__value__", None) + except NameError: + return False + if value is not None: + return _has_input_required_arm(value) + if _is_union(annotation): + return any(_has_input_required_arm(arg) for arg in get_args(annotation)) + return isinstance(annotation, type) and issubclass(annotation, InputRequiredResult) + + def _contains_resolve(annotation: Any) -> bool: """True when a `Resolve` marker is nested inside `annotation` (e.g. a union member).""" if get_origin(annotation) is Annotated: @@ -183,16 +212,12 @@ def _contains_resolve(annotation: Any) -> bool: return any(_contains_resolve(arg) for arg in get_args(annotation)) -def _elicit_return_schema(return_annotation: Any, name: str) -> type[BaseModel] | None: - """Extract `T` from a resolver return type's `Elicit[T]` arm, if present. - - Handles a bare `-> Elicit[T]` and a `-> T | Elicit[T]` union. Lets an elicited - outcome restored from `request_state` (a plain dict) be re-validated into its - model so dependent resolvers and tools receive a typed value. +def _check_elicit_return(return_annotation: Any, name: str) -> None: + """Validate the `Elicit[...]` arms of a resolver's return annotation. Raises: InvalidSignature: If the annotation has more than one `Elicit[...]` arm; - the runtime can honor only one static question schema per resolver. + a resolver asks one question - a second arm means it should be split. """ # A bare `Elicit[T]` is itself a candidate; a union contributes its members. candidates = get_args(return_annotation) if _is_union(return_annotation) else (return_annotation,) @@ -203,10 +228,6 @@ def _elicit_return_schema(return_annotation: Any, name: str) -> type[BaseModel] f"Resolver {name!r} return annotation has multiple Elicit arms; " "a resolver asks one question - split it into separate resolvers" ) - if not arms: - return None - schema = get_args(arms[0])[0] - return schema if isinstance(schema, type) and issubclass(schema, BaseModel) else None def _is_union(annotation: Any) -> bool: @@ -299,8 +320,8 @@ def analyze(fn: Callable[..., Any], stack: tuple[Hashable, ...]) -> None: "expected a Context, an Annotated[_, Resolve(...)], or a tool argument by name" ) - elicit_schema = _elicit_return_schema(hints.get("return"), _resolver_name(fn)) - plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), elicit_schema, wire_key) + _check_elicit_return(hints.get("return"), _resolver_name(fn)) + plans[key] = _ResolverPlan(fn, params, is_async_callable(fn), wire_key) for dep in nested: analyze(dep, stack + (key,)) @@ -387,9 +408,10 @@ async def resolve_arguments( negotiated protocol is >= 2026-07-28), returns an `InputRequiredResult` carrying the batched questions instead; the tool body is not run. - An eliciting resolver asks its question once - its answer is carried in - `request_state` across rounds - while a resolver that resolves without - eliciting is pure and may re-run on each round. + Each question is asked once - its answer is carried in `request_state` across + rounds and satisfies the question when the resolver asks it again. Resolver + bodies themselves may re-run on each round; a recorded answer is consulted + only when the body asks, never in place of running it. Raises: ToolError: If an elicited value is declined or cancelled and the consumer @@ -428,15 +450,6 @@ async def _resolve(fn: Callable[..., Any], res: _Resolution) -> ElicitationResul if wire_key in res.pending: # Already asked this round by another consumer; don't run the resolver again. raise _Pending - # Restore a prior round's outcome directly only when its model is known from the - # `Elicit[T]` return arm. Without that (a resolver that elicits but isn't annotated - # `-> ... Elicit[T]`), fall through and re-run the resolver so `_elicit` can - # re-validate the stored answer against the live `Elicit.schema`. - if wire_key in res.state and (plan.elicit_schema is not None or res.state[wire_key].action != "accept"): - outcome = _restore_outcome(res, wire_key, plan.elicit_schema) - if outcome is not None: - res.cache[cache_key] = outcome - return outcome kwargs: dict[str, Any] = {} dep_pending = False @@ -481,10 +494,11 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio if not res.input_required: return await res.context.elicit(elicit.message, elicit.schema) - # Answered in a prior round (restored without a known schema, e.g. an unannotated - # resolver): re-validate the stored entry against the live `Elicit.schema`. A - # recorded outcome wins over a re-sent answer; an invalid entry self-deletes and - # falls through to the fresh answer (or to re-asking). + # A recorded outcome from a prior round is consulted only here, after the body + # decided to ask, so a `request_state` entry can never stand in for a resolver's + # own computation. Re-validate it against the live `Elicit.schema`. A recorded + # outcome wins over a re-sent answer; an invalid entry self-deletes and falls + # through to the fresh answer (or to re-asking). outcome = _restore_outcome(res, key, elicit.schema) if outcome is not None: return outcome @@ -614,24 +628,21 @@ def _encode_state(outcomes: Mapping[str, _StateEntry]) -> str: return _State(v=_STATE_VERSION, outcomes=dict(outcomes)).model_dump_json() -def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel] | None) -> ElicitationResult[Any]: +def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel]) -> ElicitationResult[Any]: """Rebuild an `ElicitationResult` from a decoded `request_state` entry. Raises: - ValidationError: If `schema` is known and the entry's data does not - validate against it. + ValidationError: If an accepted entry's data does not validate against + `schema` (the live `Elicit.schema` of the question being asked). """ if entry.action == "decline": return DeclinedElicitation() if entry.action == "cancel": return CancelledElicitation() - data = entry.data - if schema is not None: - data = schema.model_validate(data) - return _accepted(data) + return _accepted(schema.model_validate(entry.data)) -def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) -> ElicitationResult[Any] | None: +def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel]) -> ElicitationResult[Any] | None: """Restore `key`'s recorded outcome from a prior round, or `None` when absent. `request_state` is client-trusted, so an entry whose data fails validation gets @@ -665,4 +676,5 @@ def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel] | None) "find_resolved_parameters", "build_resolver_plans", "resolve_arguments", + "returns_input_required", ] diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 50d28f574..23248707a 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -7,11 +7,12 @@ from mcp_types import Icon, InputRequiredResult, ToolAnnotations from pydantic import BaseModel, Field -from mcp.server.mcpserver.exceptions import ToolError +from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError from mcp.server.mcpserver.resolve import ( build_resolver_plans, find_resolved_parameters, resolve_arguments, + returns_input_required, ) from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata @@ -81,6 +82,12 @@ def from_function( context_kwarg = find_context_parameter(fn) resolved_params = find_resolved_parameters(fn) + if resolved_params and returns_input_required(fn): + raise InvalidSignature( + f"Tool {func_name!r} combines Resolve(...) parameters with an InputRequiredResult " + "return; a call has one input_required channel, so the multi-round flow is driven " + "either by resolvers or by the tool body, not both" + ) skip_names = [context_kwarg] if context_kwarg is not None else [] skip_names.extend(resolved_params) @@ -150,6 +157,15 @@ async def run( pre_validated=pre_validated, ) + # Registration rejects the annotated form of this combination; this covers + # a body that returns an InputRequiredResult without declaring it. + if self.resolved_params and isinstance(result, InputRequiredResult): + raise ToolError( + "the tool returned an InputRequiredResult but its parameters use Resolve(...); " + "a call has one input_required channel, so the multi-round flow is driven " + "either by resolvers or by the tool body, not both" + ) + if convert_result: result = self.fn_metadata.convert_result(result) diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 7e92f1c4e..571cefcb6 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -3,7 +3,7 @@ import json from collections.abc import Callable from datetime import datetime -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, TypeVar import anyio import pytest @@ -18,7 +18,8 @@ InputResponses, TextContent, ) -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError, create_model +from typing_extensions import TypeAliasType from mcp import Client, InputRequiredRoundsExceededError from mcp.client import ClientRequestContext @@ -34,8 +35,8 @@ ) from mcp.server.mcpserver.exceptions import InvalidSignature from mcp.server.mcpserver.resolve import ( + _check_elicit_return, _decode_state, - _elicit_return_schema, _encode_state, _outcome_from_state, _resolver_key, @@ -43,6 +44,7 @@ _StateEntry, _uses_input_required, find_resolved_parameters, + returns_input_required, ) from mcp.server.mcpserver.tools.base import Tool from mcp.shared.exceptions import MCPError @@ -56,6 +58,26 @@ class Confirm(BaseModel): ok: bool +class Restock(BaseModel): + needed: bool + + +# The `type X = ...` spelling of an InputRequiredResult-bearing return annotation, +# bare and generic (a subscripted alias forwards `__value__` to its origin). +IRRAlias = TypeAliasType("IRRAlias", InputRequiredResult | str) +T_alias = TypeVar("T_alias") +IRRAliasGeneric = TypeAliasType("IRRAliasGeneric", InputRequiredResult | T_alias, type_params=(T_alias,)) + + +class _UnevaluableAlias: + """Stand-in for `type X = GhostType | str` whose names exist only under + TYPE_CHECKING: accessing `__value__` evaluates the alias and raises.""" + + @property + def __value__(self) -> Any: + raise NameError("name 'GhostType' is not defined") + + class Handle(BaseModel): user_name: str = Field(alias="userName") @@ -700,7 +722,7 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E @pytest.mark.anyio -async def test_input_required_resolver_asks_and_consumes_then_never_reruns(): +async def test_input_required_asks_each_question_once_while_bodies_rerun(): mcp = MCPServer(name="ExactlyOnceMRTR") counts = {"login": 0, "confirm": 0} @@ -735,12 +757,13 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - # `confirm` can only form its question from `login`'s answer, so the auto-driver # sees the questions in two successive rounds and answers each exactly once. assert asked == ["Username?", "As octocat?"] - # An eliciting resolver runs twice - once to ask, once to consume the answer - - # then its outcome is carried in `request_state` and it never runs again. `login` - # asks in round 1 and is consumed in round 2; `confirm` (which depends on - # `login`) only forms its question once `login` is known, so it asks in round 2 - # and is consumed in round 3. Neither re-runs beyond consuming its own answer. - assert counts == {"login": 2, "confirm": 2} + # The once-per-call guarantee is about the question, not the body: a recorded + # answer is consulted only after the body asks again, so `login` runs on every + # round the call passes through (asks in round 1, consumes its answer in round 2, + # re-asks-and-restores in round 3) while the user is prompted exactly once. + # `confirm` only forms its question once `login` is known: it asks in round 2 + # and consumes in round 3. + assert counts == {"login": 3, "confirm": 2} @pytest.mark.anyio @@ -863,24 +886,23 @@ def test_state_round_trips_accept_decline_cancel(): accepted = _outcome_from_state(decoded["a"], Login) assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") - assert isinstance(_outcome_from_state(decoded["b"], None), DeclinedElicitation) - assert isinstance(_outcome_from_state(decoded["c"], None), CancelledElicitation) - raw = _outcome_from_state(decoded["d"], None) - assert isinstance(raw, AcceptedElicitation) and raw.data == "raw-token" - - -def test_elicit_return_schema_extraction(): - assert _elicit_return_schema(Elicit[Login], "r") is Login # bare Elicit[T] - assert _elicit_return_schema(Login | Elicit[Login], "r") is Login # union arm - assert _elicit_return_schema(Login, "r") is None # no Elicit arm - assert _elicit_return_schema(None, "r") is None - # The bound on `Elicit`'s parameter is unenforced at runtime, so a non-model - # subscription is constructible and must yield no schema rather than crash. - unbounded_elicit: Any = Elicit - assert _elicit_return_schema(unbounded_elicit[int], "r") is None - # Two distinct Elicit arms are ambiguous: the runtime can honor only one schema. + # Decline/cancel entries carry no data; the schema is not consulted for them. + assert isinstance(_outcome_from_state(decoded["b"], Login), DeclinedElicitation) + assert isinstance(_outcome_from_state(decoded["c"], Login), CancelledElicitation) + # An accepted restore always validates against the question's live schema - + # data that doesn't fit is rejected, never passed through raw. + with pytest.raises(ValidationError): + _outcome_from_state(decoded["d"], Login) + + +def test_check_elicit_return_allows_one_arm_and_rejects_two(): + _check_elicit_return(Elicit[Login], "r") # bare Elicit[T] + _check_elicit_return(Login | Elicit[Login], "r") # union arm + _check_elicit_return(Login, "r") # no Elicit arm + _check_elicit_return(None, "r") # unannotated + # A resolver asks one question: two distinct Elicit arms mean it should be split. with pytest.raises(InvalidSignature, match="'r' return annotation has multiple Elicit arms"): - _elicit_return_schema(Elicit[Login] | Elicit[Confirm], "r") + _check_elicit_return(Elicit[Login] | Elicit[Confirm], "r") @pytest.mark.anyio @@ -1195,13 +1217,14 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: @pytest.mark.anyio async def test_eliciting_resolver_without_elicit_arm_restores_a_typed_model(): - # A resolver annotated `-> Login` that actually returns `Elicit(...)` has no - # `Elicit[T]` return arm, so `elicit_schema` is None. Its answer, restored from - # request_state in a 3+ round flow, must still come back as a Login model (not a - # raw dict) so a dependent resolver/tool can use its attributes. + # A resolver annotated `-> object` that actually returns `Elicit(...)` declares + # no `Elicit[T]` return arm. Its answer, restored from request_state in a 3+ + # round flow, must still come back as a Login model (not a raw dict): restore + # validates against the live `Elicit.schema` the body produced, not the lying + # annotation, so a dependent resolver/tool can use its attributes. mcp = MCPServer(name="LyingAnnotation") - # Annotated without an `Elicit[T]` return arm, so `elicit_schema` is None. + # Annotated without an `Elicit[T]` return arm; the body asks anyway. async def login(ctx: Context) -> object: return Elicit("user?", Login) @@ -1591,3 +1614,181 @@ async def act( assert isinstance(final, CallToolResult) assert isinstance(final.content[0], TextContent) assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_state_entry_never_replaces_a_resolver_computed_value(): + # `request_state` is client-echoed: an accept entry under a resolver's wire key + # must only satisfy a question the resolver is actually asking, never stand in + # for the body's own computation on a branch that does not ask. + mcp = MCPServer(name="StateVsBody") + calls = {"decide": 0} + + async def decide(ctx: Context) -> Restock | Elicit[Restock]: + calls["decide"] += 1 + return Restock(needed=False) # this branch computes server-side; no question + + @mcp.tool() + async def plan_restock(restock: Annotated[Restock, Resolve(decide)]) -> str: + return str(restock.needed) + + wire_key = f"{decide.__module__}:{decide.__qualname__}" + crafted = json.dumps({"v": 1, "outcomes": {wire_key: {"action": "accept", "data": {"needed": True}}}}) + + async with Client(mcp, elicitation_callback=_never) as client: + result = await client.session.call_tool("plan_restock", {}, request_state=crafted, allow_input_required=True) + assert isinstance(result, CallToolResult) + assert isinstance(result.content[0], TextContent) + # The body ran and its computation won; the crafted entry was never consulted. + assert result.content[0].text == "False" + assert calls["decide"] == 1 + + +@pytest.mark.anyio +async def test_state_decline_entry_for_a_pure_resolver_is_ignored(): + # A decline/cancel entry can only answer a question; a resolver with no Elicit + # arm never asks one, so such an entry cannot suppress its computed value. + mcp = MCPServer(name="PureVsDecline") + + async def lookup(ctx: Context) -> Login: + return Login(username="server-side") + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(lookup)]) -> str: + return login.username + + wire_key = f"{lookup.__module__}:{lookup.__qualname__}" + crafted = json.dumps({"v": 1, "outcomes": {wire_key: {"action": "decline"}}}) + + async with Client(mcp, elicitation_callback=_never) as client: + result = await client.session.call_tool("whoami", {}, request_state=crafted, allow_input_required=True) + assert isinstance(result, CallToolResult) + assert not result.is_error + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "server-side" + + +@pytest.mark.anyio +async def test_dynamic_schema_resolver_restores_across_rounds(): + # `-> Elicit[BaseModel]` is the natural annotation for `create_model(...)` + # schemas; the restored answer must validate against the live question's + # schema, so the dynamic shape works across a multi-question chain. + mcp = MCPServer(name="DynamicSchema") + dyn = create_model("Dyn", token=(str, ...)) + + async def first(ctx: Context) -> Elicit[BaseModel]: + return Elicit("Q1?", dyn) + + async def second(f: Annotated[BaseModel, Resolve(first)], ctx: Context) -> Elicit[Confirm]: + return Elicit("Q2?", Confirm) + + @mcp.tool() + async def chain(c: Annotated[Confirm, Resolve(second)]) -> str: + return str(c.ok) + + def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: + if "Q1" in params.message: + return ElicitResult(action="accept", content={"token": "t"}) + return ElicitResult(action="accept", content={"ok": True}) + + async with Client(mcp, elicitation_callback=_never) as client: + one = await client.session.call_tool("chain", {}, allow_input_required=True) + assert isinstance(one, InputRequiredResult) + two = await client.session.call_tool( + "chain", + {}, + input_responses=_answer_round(one, answer), + request_state=one.request_state, + allow_input_required=True, + ) + assert isinstance(two, InputRequiredResult) # Q1 consumed, Q2 asked + final = await client.session.call_tool( + "chain", + {}, + input_responses=_answer_round(two, answer), + request_state=two.request_state, + allow_input_required=True, + ) + # Round 3 restores Q1's answer against the live dynamic schema and completes. + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "True" + + +@pytest.mark.parametrize( + "annotation", + [ + InputRequiredResult, + InputRequiredResult | str, + Annotated[InputRequiredResult | str, "meta"], + str | Annotated[InputRequiredResult, "meta"], # Annotated as a union member + IRRAlias, # `type X = ...` alias + IRRAliasGeneric[str], # subscripted generic alias + ], +) +def test_tool_combining_resolvers_with_input_required_return_is_rejected(annotation: Any): + # A call has one input_responses/request_state channel: resolver elicitation + # and a hand-rolled InputRequiredResult body cannot share it. + mcp = MCPServer(name="ChannelOwnership") + + async def lookup(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover - registration is rejected + + async def combo(login: Annotated[Login, Resolve(lookup)]): + raise NotImplementedError # pragma: no cover + + combo.__annotations__["return"] = annotation + with pytest.raises(InvalidSignature, match="combines Resolve\\(\\.\\.\\.\\) parameters"): + mcp.tool()(combo) + + # Without resolver parameters the hand-rolled form remains available. + @mcp.tool() + async def manual() -> InputRequiredResult: + raise NotImplementedError # pragma: no cover - only registration is exercised + + assert returns_input_required(manual) + + +def test_unevaluable_alias_and_parameterized_generics_declare_no_arm(): + # A `type X = ...` alias is evaluated lazily, so one naming TYPE_CHECKING-only + # imports raises NameError on `__value__` access: it declares no arm the check + # can see and must not break registration (the in-call guard still covers a + # body that returns an InputRequiredResult anyway). A parameterized generic + # return is never the InputRequiredResult class either. + mcp = MCPServer(name="RegistrationTolerance") + + async def lookup(ctx: Context) -> Login: + return Login(username="x") # pragma: no cover - only registration is exercised + + async def lazy(login: Annotated[Login, Resolve(lookup)]): + raise NotImplementedError # pragma: no cover + + lazy.__annotations__["return"] = _UnevaluableAlias() + assert not returns_input_required(lazy) + + @mcp.tool() + async def listy(login: Annotated[Login, Resolve(lookup)]) -> list[str]: + raise NotImplementedError # pragma: no cover + + assert not returns_input_required(listy) + + +@pytest.mark.anyio +async def test_tool_returning_input_required_dynamically_with_resolvers_is_an_error(): + # The annotated form of this combination is rejected at registration; a body + # that returns an InputRequiredResult without declaring it fails loudly at the + # same boundary instead of silently fighting the resolvers for the channel. + mcp = MCPServer(name="DynamicChannelClash") + + async def lookup(ctx: Context) -> Login: + return Login(username="x") + + @mcp.tool() + async def sneaky(login: Annotated[Login, Resolve(lookup)]): + return InputRequiredResult(input_requests={}, request_state="opaque") + + async with Client(mcp) as client: + result = await client.call_tool("sneaky", {}) + assert result.is_error + assert isinstance(result.content[0], TextContent) + assert "the multi-round flow is driven either by resolvers or by the tool body" in result.content[0].text