diff --git a/docs/concepts/elicitation/elicitation.md b/docs/concepts/elicitation/elicitation.md index 94597fa5f..78782bfbb 100644 --- a/docs/concepts/elicitation/elicitation.md +++ b/docs/concepts/elicitation/elicitation.md @@ -170,6 +170,61 @@ Here's an example implementation of how a console application might handle elici [!code-csharp[](samples/client/Program.cs?name=snippet_ElicitationHandler)] +### Multi Round-Trip Requests (MRTR) + +[MRTR](xref:mrtr) is the SEP-2322 mechanism for server-driven input requests, finalized in protocol revision `DRAFT-2026-v1`. Under the draft protocol, the server-to-client `elicitation/create` request method is removed; the recommended way to ask the user for input from a server handler is to throw and let the SDK emit an on the wire. + +> [!IMPORTANT] +> `ElicitAsync` throws `InvalidOperationException("Elicitation is not supported in stateless mode.")` whenever the server is running stateless — which includes every Streamable HTTP server under `DRAFT-2026-v1` once that revision is forced to stateless-only in a future PR. Stdio servers and current-protocol stateful Streamable HTTP servers continue to work via the legacy server-to-client `elicitation/create` request flow. For code that needs to run on stateless servers — including all `DRAFT-2026-v1` Streamable HTTP servers going forward — throw `InputRequiredException` from your handler instead. It works under both protocols and both session modes. + +For example: + +```csharp +[McpServerTool, Description("Tool that elicits via MRTR")] +public static string ElicitWithMrtr( + McpServer server, + RequestContext context) +{ + // On retry, process the client's elicitation response + if (context.Params!.InputResponses?.TryGetValue("user_input", out var response) is true) + { + var elicitResult = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + return elicitResult?.Action == "accept" + ? $"User accepted: {elicitResult.Content?.FirstOrDefault().Value}" + : "User declined."; + } + + if (!server.IsMrtrSupported) + { + return "This tool requires MRTR support (DRAFT-2026-v1, or a stateful current-protocol session)."; + } + + // First call — request user input + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_input"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm the action", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["confirm"] = new ElicitRequestParams.BooleanSchema + { + Description = "Confirm the action" + } + } + } + }) + }, + requestState: "awaiting-confirmation"); +} +``` + +> [!TIP] +> See [Multi Round-Trip Requests (MRTR)](xref:mrtr) for the full protocol details, including multiple round trips, concurrent input requests, and the compatibility matrix. + ### URL Elicitation Required Error When a tool cannot proceed without first completing a URL-mode elicitation (for example, when third-party OAuth authorization is needed), and calling `ElicitAsync` is not practical (for example in [stateless](xref:stateless) mode where server-to-client requests are disabled), the server may throw a . This is a specialized error (JSON-RPC error code `-32042`) that signals to the client that one or more URL-mode elicitations must be completed before the original request can be retried. diff --git a/docs/concepts/index.md b/docs/concepts/index.md index 6393d9997..9e5a90f25 100644 --- a/docs/concepts/index.md +++ b/docs/concepts/index.md @@ -18,6 +18,7 @@ Install the SDK and build your first MCP client and server. | [Progress tracking](progress/progress.md) | Learn how to track progress for long-running operations through notification messages. | | [Cancellation](cancellation/cancellation.md) | Learn how to cancel in-flight MCP requests using cancellation tokens and notifications. | | [Tasks](tasks/tasks.md) | Learn how to use task-based execution for long-running operations that can be polled for status and results. | +| [Multi Round-Trip Requests (MRTR)](mrtr/mrtr.md) | Learn how servers request client input during tool execution using input-required results and retries. | ### Client Features diff --git a/docs/concepts/mrtr/mrtr.md b/docs/concepts/mrtr/mrtr.md new file mode 100644 index 000000000..1d1ebce32 --- /dev/null +++ b/docs/concepts/mrtr/mrtr.md @@ -0,0 +1,291 @@ +--- +title: Multi Round-Trip Requests (MRTR) +author: halter73 +description: How servers request client input during tool execution using Multi Round-Trip Requests. +uid: mrtr +--- + +# Multi Round-Trip Requests (MRTR) + + +> [!WARNING] +> MRTR is part of the **`DRAFT-2026-v1`** revision of the MCP specification ([SEP-2322](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2322)). The wire format and API surface may change before the revision is ratified. See the [Experimental APIs](../../experimental.md) documentation for details on working with experimental APIs. + +Multi Round-Trip Requests (MRTR) let a server tool request input from the client — such as [elicitation](xref:elicitation), [sampling](xref:sampling), or [roots](xref:roots) — as part of a single tool call, without requiring a separate server-to-client JSON-RPC request for each interaction. Instead of returning a final result, the server returns an **incomplete result** containing one or more input requests. The client fulfills those requests and retries the original tool call with the responses attached. + +## Overview + +MRTR is useful when: + +- A tool needs user confirmation before proceeding (elicitation). +- A tool needs LLM reasoning from the client (sampling). +- A tool needs an updated list of client roots. +- A tool needs to perform multiple rounds of interaction in a single logical operation. +- A stateless server needs to orchestrate multi-step flows without keeping handler state in memory between rounds. + +## How MRTR works + +1. The client calls a tool on the server via `tools/call`. +2. The server tool determines it needs client input and returns an `InputRequiredResult` containing `inputRequests` and/or `requestState`. +3. The client resolves each input request (for example by prompting the user for elicitation, calling an LLM for sampling, or listing its roots). +4. The client retries the original `tools/call` with `inputResponses` (keyed to the input requests) and `requestState` echoed back. +5. The server processes the responses and either returns a final result or another `InputRequiredResult` for additional rounds. + +## Opting in + +MRTR activates when both peers negotiate protocol revision **`DRAFT-2026-v1`** during `initialize`. The C# SDK opts in by listing `DRAFT-2026-v1` as a supported protocol version on the client; servers automatically accept it when offered. No experimental flags are required. + +```csharp +// Client +var clientOptions = new McpClientOptions +{ + ProtocolVersion = "DRAFT-2026-v1", + Handlers = new McpClientHandlers + { + ElicitationHandler = HandleElicitationAsync, + SamplingHandler = HandleSamplingAsync, + } +}; +``` + +Under `DRAFT-2026-v1`, MRTR is the recommended way to obtain client input from a server handler. The spec removes the legacy server-to-client `elicitation/create`, `sampling/createMessage`, and `roots/list` request methods, so any code that needs to work on a `DRAFT-2026-v1` Streamable HTTP server (which will be stateless-only in a future revision) must use `InputRequiredException` rather than , , or . The legacy methods still work on stateful sessions — that's how stdio servers keep working under draft today — but they throw `InvalidOperationException("X is not supported in stateless mode.")` on any stateless session, current or draft. + +Under the current protocol revision (`2025-06-18` and earlier), `InputRequiredException` is still supported in stateful sessions via a backward-compatibility resolver — see [Compatibility](#compatibility) below. + +## Authoring an MRTR tool + +A tool participates in MRTR by throwing with an describing what it needs. On retry, the client's responses arrive on the request parameters and the tool inspects them to decide what to do next. + +### Checking MRTR support + +Tools should check before throwing `InputRequiredException`. It returns `true` when either: + +- The negotiated protocol revision is `DRAFT-2026-v1` (MRTR is native), or +- The session is stateful under the current protocol (the SDK can resolve input requests via legacy JSON-RPC and retry the handler). + +```csharp +[McpServerTool, Description("A tool that uses MRTR")] +public static string MyTool( + McpServer server, + RequestContext context) +{ + if (!server.IsMrtrSupported) + { + return "This tool requires a client that negotiates DRAFT-2026-v1, " + + "or a stateful current-protocol session."; + } + + // ... MRTR logic +} +``` + +### Returning an incomplete result + +Throw to return an incomplete result. The exception carries an containing `inputRequests` and/or `requestState`: + +```csharp +[McpServerTool, Description("Tool managing its own MRTR flow")] +public static string AnswerTool( + McpServer server, + RequestContext context, + [Description("The user's question")] string question) +{ + var requestState = context.Params!.RequestState; + var inputResponses = context.Params!.InputResponses; + + // On retry, process the client's responses + if (requestState is not null && inputResponses is not null) + { + var elicitResult = inputResponses["user_answer"].Deserialize(InputResponse.ElicitResultJsonTypeInfo); + return $"You answered: {elicitResult?.Content?.FirstOrDefault().Value}"; + } + + if (!server.IsMrtrSupported) + { + return "MRTR is not supported by this client."; + } + + // First call — request user input + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_answer"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"Please answer: {question}", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["answer"] = new ElicitRequestParams.StringSchema + { + Description = "Your answer" + } + } + } + }) + }, + requestState: "awaiting-answer"); +} +``` + +### Accessing retry data + +When the client retries a tool call, the retry data is available on the request parameters: + +- — a dictionary of client responses keyed by the same keys used in `inputRequests`. +- — the opaque state string echoed back by the client. + +Use with the `JsonTypeInfo` matching the response type. The expected type follows from the matching in the original `inputRequests` map — there is no on-the-wire discriminator. + +- Elicitation — `response.Deserialize(InputResponse.ElicitResultJsonTypeInfo)` +- Sampling — `response.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)` +- Roots list — `response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)` + +### Load shedding with requestState-only responses + +A server can return a `requestState`-only incomplete result (without any `inputRequests`) to defer processing. This is useful for load shedding or breaking up long-running work across multiple requests: + +```csharp +[McpServerTool, Description("Tool that defers work using requestState")] +public static string DeferredTool( + McpServer server, + RequestContext context) +{ + var requestState = context.Params!.RequestState; + + if (requestState is not null) + { + // Resume deferred work + var state = JsonSerializer.Deserialize( + Convert.FromBase64String(requestState)); + return $"Completed step {state!.Step}"; + } + + if (!server.IsMrtrSupported) + { + return "MRTR is not supported by this client."; + } + + // Defer work to a later retry + var initialState = new MyState { Step = 1 }; + throw new InputRequiredException( + requestState: Convert.ToBase64String( + JsonSerializer.SerializeToUtf8Bytes(initialState))); +} +``` + +The client automatically retries `requestState`-only incomplete results, echoing the state back without needing to resolve any input requests. + +### Multiple round trips + +A tool can perform multiple rounds of interaction by throwing `InputRequiredException` multiple times across retries. Use `requestState` to track which round you're on: + +```csharp +[McpServerTool, Description("Multi-step wizard")] +public static string WizardTool( + McpServer server, + RequestContext context) +{ + var requestState = context.Params!.RequestState; + var inputResponses = context.Params!.InputResponses; + + if (requestState == "step-2" && inputResponses is not null) + { + var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + var age = inputResponses["age"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + return $"Welcome, {name}! You are {age} years old."; + } + + if (requestState == "step-1" && inputResponses is not null) + { + var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + + // Second round — ask for age + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["age"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"Hi {name}! How old are you?", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["age"] = new ElicitRequestParams.NumberSchema + { + Description = "Your age" + } + } + } + }) + }, + requestState: "step-2"); + } + + if (!server.IsMrtrSupported) + { + return "MRTR is not supported. Please use a compatible client."; + } + + // First round — ask for name + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What's your name?", + RequestedSchema = new() + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema + { + Description = "Your name" + } + } + } + }) + }, + requestState: "step-1"); +} +``` + +### Providing custom error messages + +When MRTR is not supported, you can provide domain-specific guidance: + +```csharp +if (!server.IsMrtrSupported) +{ + return "This tool requires interactive input. To use it:\n" + + "1. Connect with a client that negotiates MCP protocol revision DRAFT-2026-v1, or\n" + + "2. Use a stateful current-protocol session so the server can resolve the input requests for you.\n" + + "\nStateless current-protocol sessions cannot resolve MRTR input requests."; +} +``` + +## Compatibility + +The SDK supports `InputRequiredException` across two protocol revisions and two session modes: + +| Negotiated protocol | Session mode | Behavior | +|---|---|---| +| `DRAFT-2026-v1` | Stateful | Native MRTR — `InputRequiredResult` is serialized directly to the wire. | +| `DRAFT-2026-v1` | Stateless | Native MRTR — `InputRequiredResult` is serialized directly to the wire. No server-side handler state needed. | +| Current (`2025-06-18` and earlier) | Stateful | Backward-compatibility resolver — the SDK sends standard `elicitation/create` / `sampling/createMessage` / `roots/list` JSON-RPC requests to the client, collects the responses, and retries the handler with `inputResponses` populated. Up to 10 retry rounds. | +| Current (`2025-06-18` and earlier) | Stateless | **Not supported** — `InputRequiredException` raises an `McpException`. The client doesn't speak MRTR, and the server can't resolve input requests via JSON-RPC without a persistent session. | + +> [!NOTE] +> The backcompat resolver is intentionally limited to 10 retry rounds. Tools that need more rounds should require `DRAFT-2026-v1` (check `IsMrtrSupported`). + +### Why `ElicitAsync` / `SampleAsync` / `RequestRootsAsync` throw on stateless servers + +`ElicitAsync` / `SampleAsync` / `RequestRootsAsync` issue a JSON-RPC request to the client and wait for the response on the same session. Stateless servers don't have a persistent session to wait on, so the SDK fails fast with `InvalidOperationException("X is not supported in stateless mode.")` (the check is `McpServer.ClientCapabilities is null`, which is the SDK's proxy for stateless). + +Under the current protocol revision (`2025-06-18` and earlier), stdio and stateful Streamable HTTP keep `ClientCapabilities` populated, so the legacy methods work normally and remain the recommended way to do one-shot client interactions. Under `DRAFT-2026-v1`, the spec removes those request methods from Streamable HTTP entirely; the SDK still allows the legacy methods on draft stdio sessions because stdio is implicitly single-process / stateful and the client handler is wired up regardless of negotiated revision. `InputRequiredException` is the way to write tools that work on every supported configuration. + +### Future direction + +The `DRAFT-2026-v1` revision is moving toward a stateless-only model: `Mcp-Session-Id` is being removed, and Streamable HTTP servers will run statelessly by default under the draft revision. When that lands, the `Stateful` row for `DRAFT-2026-v1` in the compatibility matrix above collapses into the `Stateless` row (Streamable HTTP under draft becomes stateless-only), and `InputRequiredException` becomes uniformly required for non-stdio servers. The current-protocol resolver path will remain for backward compatibility with older clients and stateful servers. + +This work is a follow-up to the present PR. diff --git a/docs/concepts/roots/roots.md b/docs/concepts/roots/roots.md index 7c09e53ad..213d317c0 100644 --- a/docs/concepts/roots/roots.md +++ b/docs/concepts/roots/roots.md @@ -103,3 +103,43 @@ server.RegisterNotificationHandler( Console.WriteLine($"Roots updated. {result.Roots.Count} roots available."); }); ``` + +### Multi Round-Trip Requests (MRTR) + +[MRTR](xref:mrtr) is the SEP-2322 mechanism for server-driven input requests, finalized in protocol revision `DRAFT-2026-v1`. Under the draft protocol, the server-to-client `roots/list` request method is removed; the recommended way to ask the client for its roots from a server handler is to throw and let the SDK emit an on the wire. + +> [!IMPORTANT] +> `RequestRootsAsync` throws `InvalidOperationException("Roots are not supported in stateless mode.")` whenever the server is running stateless — which includes every Streamable HTTP server under `DRAFT-2026-v1` once that revision is forced to stateless-only in a future PR. Stdio servers and current-protocol stateful Streamable HTTP servers continue to work via the legacy server-to-client `roots/list` request flow. For code that needs to run on stateless servers — including all `DRAFT-2026-v1` Streamable HTTP servers going forward — throw `InputRequiredException` from your handler instead. It works under both protocols and both session modes. + +For example: + +```csharp +[McpServerTool, Description("Tool that requests roots via MRTR")] +public static string ListRootsWithMrtr( + McpServer server, + RequestContext context) +{ + // On retry, process the client's roots response + if (context.Params!.InputResponses?.TryGetValue("get_roots", out var response) is true) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots ?? []; + return $"Found {roots.Count} roots: {string.Join(", ", roots.Select(r => r.Uri))}"; + } + + if (!server.IsMrtrSupported) + { + return "This tool requires MRTR support (DRAFT-2026-v1, or a stateful current-protocol session)."; + } + + // First call — request the client's root list + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["get_roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "awaiting-roots"); +} +``` + +> [!TIP] +> See [Multi Round-Trip Requests (MRTR)](xref:mrtr) for the full protocol details, including load shedding, multiple round trips, and the compatibility matrix. diff --git a/docs/concepts/sampling/sampling.md b/docs/concepts/sampling/sampling.md index 4f14a4ee0..bac6ed5ab 100644 --- a/docs/concepts/sampling/sampling.md +++ b/docs/concepts/sampling/sampling.md @@ -120,3 +120,55 @@ McpClientOptions options = new() ### Capability negotiation Sampling requires the client to advertise the `sampling` capability. This is handled automatically — when a is set, the client includes the sampling capability during initialization. The server can check whether the client supports sampling before calling ; if sampling is not supported, the method throws . + +### Multi Round-Trip Requests (MRTR) + +[MRTR](xref:mrtr) is the SEP-2322 mechanism for server-driven input requests, finalized in protocol revision `DRAFT-2026-v1`. Under the draft protocol, the server-to-client `sampling/createMessage` request method is removed; the recommended way to ask the client to sample from a server handler is to throw and let the SDK emit an on the wire. + +> [!IMPORTANT] +> `SampleAsync` and `AsSamplingChatClient` throw `InvalidOperationException("Sampling is not supported in stateless mode.")` whenever the server is running stateless — which includes every Streamable HTTP server under `DRAFT-2026-v1` once that revision is forced to stateless-only in a future PR. Stdio servers and current-protocol stateful Streamable HTTP servers continue to work via the legacy server-to-client `sampling/createMessage` request flow. For code that needs to run on stateless servers — including all `DRAFT-2026-v1` Streamable HTTP servers going forward — throw `InputRequiredException` from your handler instead. It works under both protocols and both session modes. + +For example: + +```csharp +[McpServerTool, Description("Tool that samples via MRTR")] +public static string SampleWithMrtr( + McpServer server, + RequestContext context) +{ + // On retry, process the client's sampling response + if (context.Params!.InputResponses?.TryGetValue("llm_call", out var response) is true) + { + var text = response.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content + .OfType().FirstOrDefault()?.Text; + return $"LLM said: {text}"; + } + + if (!server.IsMrtrSupported) + { + return "This tool requires MRTR support (DRAFT-2026-v1, or a stateful current-protocol session)."; + } + + // First call — request LLM completion from the client + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["llm_call"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Summarize the data" }] + } + ], + MaxTokens = 256 + }) + }, + requestState: "awaiting-sample"); +} +``` + +> [!TIP] +> See [Multi Round-Trip Requests (MRTR)](xref:mrtr) for the full protocol details, including load shedding, multiple round trips, and the compatibility matrix. diff --git a/docs/concepts/toc.yml b/docs/concepts/toc.yml index bd5474338..e0708cb75 100644 --- a/docs/concepts/toc.yml +++ b/docs/concepts/toc.yml @@ -19,6 +19,8 @@ items: uid: cancellation - name: Tasks uid: tasks + - name: Multi Round-Trip Requests (MRTR) + uid: mrtr - name: Client Features items: - name: Sampling diff --git a/src/Common/Experimentals.cs b/src/Common/Experimentals.cs index 7e7e969bb..e356480ed 100644 --- a/src/Common/Experimentals.cs +++ b/src/Common/Experimentals.cs @@ -110,4 +110,23 @@ internal static class Experimentals /// URL for the experimental RunSessionHandler API. /// public const string RunSessionHandler_Url = "https://github.com/modelcontextprotocol/csharp-sdk/blob/main/docs/list-of-diagnostics.md#mcpexp002"; + + /// + /// Diagnostic ID for the experimental Multi Round-Trip Requests (MRTR) feature. + /// + /// + /// This uses the same diagnostic ID as because MRTR + /// is an experimental feature in the MCP specification (SEP-2322). + /// + public const string Mrtr_DiagnosticId = "MCPEXP001"; + + /// + /// Message for the experimental MRTR feature. + /// + public const string Mrtr_Message = "The Multi Round-Trip Requests (MRTR) feature is experimental per the MCP specification (SEP-2322) and is subject to change."; + + /// + /// URL for the experimental MRTR feature. + /// + public const string Mrtr_Url = "https://github.com/modelcontextprotocol/csharp-sdk/blob/main/docs/list-of-diagnostics.md#mcpexp001"; } diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj index 762091667..44bac1bc9 100644 --- a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj +++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj @@ -10,6 +10,7 @@ ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK. README.md true + $(NoWarn);MCPEXP001 diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 49922b8d9..fb6b822fa 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -495,12 +495,25 @@ internal static string MakeNewSessionId() // Implementation for reading a JSON-RPC message from the request body var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); - if (context.User?.Identity?.IsAuthenticated == true && message is not null) + if (message is not null) { - message.Context = new() + var protocolVersion = context.Request.Headers[McpProtocolVersionHeaderName].ToString(); + var isAuthenticated = context.User?.Identity?.IsAuthenticated == true; + + if (isAuthenticated || !string.IsNullOrEmpty(protocolVersion)) { - User = context.User, - }; + message.Context ??= new(); + + if (isAuthenticated) + { + message.Context.User = context.User; + } + + if (!string.IsNullOrEmpty(protocolVersion)) + { + message.Context.ProtocolVersion = protocolVersion; + } + } } return message; diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 66410b272..a90ddf6e8 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Protocol; using System.Collections.Concurrent; using System.Text.Json; +using System.Text.Json.Nodes; namespace ModelContextProtocol.Client; @@ -142,6 +143,8 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not RequestMethods.SamplingCreateMessage, async (request, jsonRpcRequest, cancellationToken) => { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.SamplingCreateMessage); + // Check if this is a task-augmented request if (request?.Task is { } taskMetadata) { @@ -176,10 +179,14 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not { requestHandlers.Set( RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), + (request, _, cancellationToken) => + { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.SamplingCreateMessage); + return samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken); + }, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, McpJsonUtilities.JsonContext.Default.CreateMessageResult); } @@ -192,7 +199,11 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not { requestHandlers.Set( RequestMethods.RootsList, - (request, _, cancellationToken) => rootsHandler(request, cancellationToken), + (request, _, cancellationToken) => + { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.RootsList); + return rootsHandler(request, cancellationToken); + }, McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, McpJsonUtilities.JsonContext.Default.ListRootsResult); @@ -209,6 +220,8 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not RequestMethods.ElicitationCreate, async (request, jsonRpcRequest, cancellationToken) => { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.ElicitationCreate); + // Check if this is a task-augmented request if (request?.Task is { } taskMetadata) { @@ -241,6 +254,7 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not RequestMethods.ElicitationCreate, async (request, _, cancellationToken) => { + WarnIfLegacyRequestOnMrtrSession(RequestMethods.ElicitationCreate); var result = await elicitationHandler(request, cancellationToken).ConfigureAwait(false); return ElicitResult.WithDefaults(request, result); }, @@ -547,6 +561,98 @@ private void RegisterTaskHandlers(RequestHandlers requestHandlers, IMcpTaskStore /// public override Task Completion => _sessionHandler.CompletionTask; + /// + private async ValueTask> ResolveInputRequestsAsync( + IDictionary inputRequests, + CancellationToken cancellationToken) + { + // Resolve all input requests concurrently. If any fails, cancel the rest so user-facing + // handlers (sampling/elicitation prompts) don't keep running for a request whose caller + // has already given up, and ensure exceptions from late-completing tasks are observed. + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + var keyed = new (string Key, Task Task)[inputRequests.Count]; + int i = 0; + foreach (var kvp in inputRequests) + { + keyed[i++] = (kvp.Key, ResolveInputRequestAsync(kvp.Value, linkedCts.Token)); + } + + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + linkedCts.Cancel(); + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + // Observed; the original exception is the one we want to surface. + } + throw; + } + + var responses = new Dictionary(keyed.Length); + foreach (var (key, task) in keyed) + { + responses[key] = task.Result; + } + return responses; + } + + private async Task ResolveInputRequestAsync(InputRequest inputRequest, CancellationToken cancellationToken) + { + switch (inputRequest.Method) + { + case RequestMethods.SamplingCreateMessage: + if (_options.Handlers.SamplingHandler is { } samplingHandler) + { + var samplingParams = inputRequest.SamplingParams + ?? throw new McpException($"Failed to deserialize sampling parameters from MRTR input request."); + var result = await samplingHandler( + samplingParams, + samplingParams.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken).ConfigureAwait(false); + return InputResponse.FromSamplingResult(result); + } + + throw new InvalidOperationException( + $"Server sent a sampling input request, but no {nameof(McpClientHandlers.SamplingHandler)} is registered."); + + case RequestMethods.ElicitationCreate: + if (_options.Handlers.ElicitationHandler is { } elicitationHandler) + { + var elicitParams = inputRequest.ElicitationParams + ?? throw new McpException($"Failed to deserialize elicitation parameters from MRTR input request."); + var result = await elicitationHandler(elicitParams, cancellationToken).ConfigureAwait(false); + result = ElicitResult.WithDefaults(elicitParams, result); + return InputResponse.FromElicitResult(result); + } + + throw new InvalidOperationException( + $"Server sent an elicitation input request, but no {nameof(McpClientHandlers.ElicitationHandler)} is registered."); + + case RequestMethods.RootsList: + if (_options.Handlers.RootsHandler is { } rootsHandler) + { + // ListRootsRequest params are optional per the spec, so fall back to an empty params instance. + var rootsParams = inputRequest.RootsParams ?? new ListRootsRequestParams(); + var result = await rootsHandler(rootsParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromRootsResult(result); + } + + throw new InvalidOperationException( + $"Server sent a roots list input request, but no {nameof(McpClientHandlers.RootsHandler)} is registered."); + + default: + throw new NotSupportedException($"Unsupported input request method: '{inputRequest.Method}'."); + } + } + /// /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. /// @@ -654,6 +760,7 @@ internal void ResumeSession(ResumeClientSessionOptions resumeOptions) LogClientSessionResumed(_endpointName); } + /// /// public override void AddKnownTools(IEnumerable tools) { @@ -718,13 +825,13 @@ public override void ClearKnownTools() } /// - public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + public override async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { // For tools/call requests, attach the cached tool definition to the message context // so the transport can add custom Mcp-Param-* headers based on x-mcp-header schema annotations. if (request.Method == RequestMethods.ToolsCall && - request.Params is System.Text.Json.Nodes.JsonObject paramsObj && - paramsObj.TryGetPropertyValue("name", out var nameNode) && + request.Params is System.Text.Json.Nodes.JsonObject paramsObjForHeaders && + paramsObjForHeaders.TryGetPropertyValue("name", out var nameNode) && nameNode?.GetValue() is { } toolName) { if (_toolCache.TryGetValue(toolName, out var tool)) @@ -739,7 +846,61 @@ request.Params is System.Text.Json.Nodes.JsonObject paramsObj && } } - return _sessionHandler.SendRequestAsync(request, cancellationToken); + const int maxRetries = 10; + + for (int attempt = 0; attempt <= maxRetries; attempt++) + { + JsonRpcResponse response = await _sessionHandler.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + + // Check if the result is an InputRequiredResult by looking at result_type. + if (response.Result is JsonObject resultObj && + resultObj.TryGetPropertyValue("resultType", out var resultTypeNode) && + resultTypeNode?.GetValue() is "input_required") + { + WarnIfInputRequiredResultOnNonMrtrSession(request.Method); + + var inputRequiredResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.InputRequiredResult) + ?? throw new JsonException("Failed to deserialize InputRequiredResult."); + + if (inputRequiredResult.InputRequests is { Count: > 0 } inputRequests) + { + IDictionary inputResponses = + await ResolveInputRequestsAsync(inputRequests, cancellationToken).ConfigureAwait(false); + + // Clone the original request params and add inputResponses + requestState for the retry. + var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject(); + + paramsObj["inputResponses"] = JsonSerializer.SerializeToNode( + inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse); + + if (inputRequiredResult.RequestState is { } requestState) + { + paramsObj["requestState"] = requestState; + } + + request = new JsonRpcRequest { Method = request.Method, Params = paramsObj, Context = request.Context }; + } + else if (inputRequiredResult.RequestState is not null) + { + // No input requests but has requestState (e.g., load shedding) — just retry with state. + var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject(); + paramsObj["requestState"] = inputRequiredResult.RequestState; + paramsObj.Remove("inputResponses"); + + request = new JsonRpcRequest { Method = request.Method, Params = paramsObj, Context = request.Context }; + } + else + { + throw new McpException("Server returned an InputRequiredResult without inputRequests or requestState."); + } + + continue; // retry with the updated request + } + + return response; + } + + throw new McpException($"Server returned InputRequiredResult more than {maxRetries} times."); } /// @@ -775,6 +936,30 @@ public override async ValueTask DisposeAsync() await Completion.ConfigureAwait(false); } + /// Logs a warning if the session negotiated MRTR but the server sent a legacy JSON-RPC request. + private void WarnIfLegacyRequestOnMrtrSession(string method) + { + if (_negotiatedProtocolVersion == McpSessionHandler.DraftProtocolVersion) + { + LogLegacyRequestOnMrtrSession(_endpointName, method); + } + } + + /// Logs a warning if the session did not negotiate MRTR but the server sent an InputRequiredResult. + private void WarnIfInputRequiredResultOnNonMrtrSession(string method) + { + if (_negotiatedProtocolVersion != McpSessionHandler.DraftProtocolVersion) + { + LogInputRequiredResultOnNonMrtrSession(_endpointName, method, _negotiatedProtocolVersion); + } + } + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received legacy '{Method}' JSON-RPC request on session that negotiated MRTR. The server should use InputRequiredResult instead of sending direct requests.")] + private partial void LogLegacyRequestOnMrtrSession(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received InputRequiredResult for '{Method}' on session that did not negotiate MRTR (protocol version '{ProtocolVersion}'). The server may not be spec-compliant.")] + private partial void LogInputRequiredResultOnNonMrtrSession(string endpointName, string method, string? protocolVersion); + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); @@ -798,5 +983,4 @@ public override async ValueTask DisposeAsync() [LoggerMessage(Level = LogLevel.Warning, Message = "Tool '{ToolName}' excluded from tools/list: {Reason}")] private partial void LogToolRejected(string toolName, string reason); - } diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index abb6d29df..b4613d9f2 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; using System.Diagnostics.CodeAnalysis; @@ -144,6 +144,13 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(SubscribeRequestParams))] [JsonSerializable(typeof(UnsubscribeRequestParams))] + // MCP MRTR (Multi Round-Trip Requests) + [JsonSerializable(typeof(InputRequiredResult))] + [JsonSerializable(typeof(InputRequest))] + [JsonSerializable(typeof(InputResponse))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(IDictionary))] + // MCP Task Request Params / Results [JsonSerializable(typeof(McpTask))] [JsonSerializable(typeof(McpTaskStatus))] diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index 4201f9833..73d99da71 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -68,7 +68,7 @@ public abstract partial class McpSession : IAsyncDisposable /// /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous send operation. - /// The transport is not connected. + /// The transport is not connected, or is a . Use for requests. /// is . /// /// diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 77c18b8be..e874e6724 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -31,6 +31,13 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable /// The latest version of the protocol supported by this implementation. internal const string LatestProtocolVersion = "2025-11-25"; + /// + /// The draft protocol version that enables MRTR (Multi Round-Trip Requests) per SEP-2322. + /// Clients and servers opt in by setting + /// or to this value. + /// + internal const string DraftProtocolVersion = "DRAFT-2026-v1"; + /// /// All protocol versions supported by this implementation. /// Keep in sync with s_supportedProtocolVersions in StreamableHttpHandler. @@ -41,7 +48,7 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable "2025-03-26", "2025-06-18", LatestProtocolVersion, - "DRAFT-2026-v1", + DraftProtocolVersion, ]; /// @@ -642,6 +649,13 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can { Throw.IfNull(message); + if (message is JsonRpcRequest request) + { + throw new InvalidOperationException( + $"Cannot send '{request.Method}' request via {nameof(SendMessageAsync)}. " + + $"Use {nameof(SendRequestAsync)} instead to get a correlated response."); + } + cancellationToken.ThrowIfCancellationRequested(); Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; diff --git a/src/ModelContextProtocol.Core/Protocol/InputRequest.cs b/src/ModelContextProtocol.Core/Protocol/InputRequest.cs new file mode 100644 index 000000000..bd9161423 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputRequest.cs @@ -0,0 +1,197 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents a server-initiated request that the client must fulfill as part of an MRTR +/// (Multi Round-Trip Request) flow. +/// +/// +/// +/// An wraps a server-to-client request such as +/// , , +/// or . It is included in an +/// when the server needs additional input before it can complete a client-initiated request. +/// +/// +/// The property identifies the type of request, and the corresponding +/// parameters can be accessed via the typed accessor properties. +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +[JsonConverter(typeof(Converter))] +public sealed class InputRequest +{ + /// + /// Gets or sets the method name identifying the type of this input request. + /// + /// + /// Standard values include: + /// + /// A sampling request. + /// An elicitation request. + /// A roots list request. + /// + /// + [JsonPropertyName("method")] + public required string Method { get; set; } + + /// + /// Gets or sets the raw JSON parameters for this input request. + /// + /// + /// Use the typed accessor properties (, , + /// ) for convenient strongly-typed access. + /// + [JsonPropertyName("params")] + public JsonElement? Params { get; set; } + + /// + /// Gets the parameters as when + /// is . + /// + /// The deserialized sampling parameters, or if the method does not match or params are absent. + [JsonIgnore] + public CreateMessageRequestParams? SamplingParams => + string.Equals(Method, RequestMethods.SamplingCreateMessage, StringComparison.Ordinal) && Params is { } p + ? JsonSerializer.Deserialize(p, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams) + : null; + + /// + /// Gets the parameters as when + /// is . + /// + /// The deserialized elicitation parameters, or if the method does not match or params are absent. + [JsonIgnore] + public ElicitRequestParams? ElicitationParams => + string.Equals(Method, RequestMethods.ElicitationCreate, StringComparison.Ordinal) && Params is { } p + ? JsonSerializer.Deserialize(p, McpJsonUtilities.JsonContext.Default.ElicitRequestParams) + : null; + + /// + /// Gets the parameters as when + /// is . + /// + /// The deserialized roots list parameters, or if the method does not match or params are absent. + [JsonIgnore] + public ListRootsRequestParams? RootsParams => + string.Equals(Method, RequestMethods.RootsList, StringComparison.Ordinal) && Params is { } p + ? JsonSerializer.Deserialize(p, McpJsonUtilities.JsonContext.Default.ListRootsRequestParams) + : null; + + /// + /// Creates an for a sampling request. + /// + /// The sampling request parameters. + /// A new instance. + public static InputRequest ForSampling(CreateMessageRequestParams requestParams) + { + Throw.IfNull(requestParams); + return new() + { + Method = RequestMethods.SamplingCreateMessage, + Params = JsonSerializer.SerializeToElement(requestParams, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams), + }; + } + + /// + /// Creates an for an elicitation request. + /// + /// The elicitation request parameters. + /// A new instance. + public static InputRequest ForElicitation(ElicitRequestParams requestParams) + { + Throw.IfNull(requestParams); + return new() + { + Method = RequestMethods.ElicitationCreate, + Params = JsonSerializer.SerializeToElement(requestParams, McpJsonUtilities.JsonContext.Default.ElicitRequestParams), + }; + } + + /// + /// Creates an for a roots list request. + /// + /// The roots list request parameters. + /// A new instance. + public static InputRequest ForRootsList(ListRootsRequestParams requestParams) + { + Throw.IfNull(requestParams); + return new() + { + Method = RequestMethods.RootsList, + Params = JsonSerializer.SerializeToElement(requestParams, McpJsonUtilities.JsonContext.Default.ListRootsRequestParams), + }; + } + + /// Provides JSON serialization support for . + public sealed class Converter : JsonConverter + { + /// + public override InputRequest? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected StartObject token."); + } + + string? method = null; + JsonElement? parameters = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected PropertyName token."); + } + + string propertyName = reader.GetString()!; + reader.Read(); + + switch (propertyName) + { + case "method": + method = reader.GetString(); + break; + case "params": + parameters = JsonElement.ParseValue(ref reader); + break; + default: + reader.Skip(); + break; + } + } + + if (method is null) + { + throw new JsonException("InputRequest must have a 'method' property."); + } + + return new InputRequest + { + Method = method, + Params = parameters, + }; + } + + /// + public override void Write(Utf8JsonWriter writer, InputRequest value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + writer.WriteString("method", value.Method); + if (value.Params is { } p) + { + writer.WritePropertyName("params"); + p.WriteTo(writer); + } + writer.WriteEndObject(); + } + } +} diff --git a/src/ModelContextProtocol.Core/Protocol/InputRequiredException.cs b/src/ModelContextProtocol.Core/Protocol/InputRequiredException.cs new file mode 100644 index 000000000..4f39b17a5 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputRequiredException.cs @@ -0,0 +1,109 @@ +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Protocol; + +/// +/// The exception that is thrown by a server handler to return an +/// to the client, signaling that additional input is needed before the request can be completed. +/// +/// +/// +/// This exception is part of the Multi Round-Trip Requests (MRTR) API. Tool handlers +/// throw this exception to directly control the input-required result payload, including +/// and . +/// +/// +/// For stateless servers, this enables multi-round-trip flows without requiring the handler to stay +/// alive between round trips. The server encodes its state in +/// and receives it back on retry via . +/// +/// +/// To return a requestState-only response (e.g., for load shedding), omit +/// and set only . +/// The client will retry the request with the state echoed back. +/// +/// +/// This exception can only be used when MRTR is supported by the client. Check +/// before throwing. If thrown when MRTR is not +/// supported, the exception will propagate as a JSON-RPC internal error. +/// +/// +/// +/// +/// [McpServerTool, Description("A stateless tool using MRTR")] +/// public static string MyTool(McpServer server, RequestContext<CallToolRequestParams> context) +/// { +/// if (context.Params.RequestState is { } state) +/// { +/// // Retry: process accumulated state and input responses +/// var responses = context.Params.InputResponses; +/// return "Final result"; +/// } +/// +/// if (!server.IsMrtrSupported) +/// { +/// return "This tool requires MRTR support."; +/// } +/// +/// throw new InputRequiredException( +/// inputRequests: new Dictionary<string, InputRequest> +/// { +/// ["user_input"] = InputRequest.ForElicitation(new ElicitRequestParams { ... }) +/// }, +/// requestState: "encoded-state"); +/// } +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +public class InputRequiredException : Exception +{ + /// + /// Initializes a new instance of the class + /// with the specified . + /// + /// The input-required result to return to the client. + public InputRequiredException(InputRequiredResult result) + : base("The server returned an input-required result requiring additional client input.") + { + Throw.IfNull(result); + Result = result; + } + + /// + /// Initializes a new instance of the class + /// with the specified input requests and/or request state. + /// + /// + /// Server-initiated requests that the client must fulfill before retrying. + /// Keys are server-assigned identifiers. + /// + /// + /// Opaque state to be echoed back by the client when retrying. The client must + /// treat this as an opaque blob and must not inspect or modify it. + /// + /// + /// Both and are . + /// At least one must be provided. + /// + public InputRequiredException( + IDictionary? inputRequests = null, + string? requestState = null) + : base("The server returned an input-required result requiring additional client input.") + { + if (inputRequests is null && requestState is null) + { + throw new ArgumentException("At least one of inputRequests or requestState must be provided."); + } + + Result = new InputRequiredResult + { + InputRequests = inputRequests, + RequestState = requestState, + }; + } + + /// + /// Gets the input-required result to return to the client. + /// + public InputRequiredResult Result { get; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs b/src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs new file mode 100644 index 000000000..1391e2ba9 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs @@ -0,0 +1,67 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents an input-required result sent by the server to indicate that additional input is needed +/// before the request can be completed. +/// +/// +/// +/// An is returned in response to a client-initiated request when +/// the server needs the client to fulfill one or more server-initiated requests before it can produce +/// a final result. Per SEP-2322 the wire format is valid for , +/// , and resources/read, but this SDK currently only wires +/// the MRTR interceptor into ; throwing +/// from a prompts or resources handler will surface as an internal +/// error until the other methods are opted in. +/// +/// +/// At least one of or must be present. +/// +/// +/// This type is part of the Multi Round-Trip Requests (MRTR) mechanism defined in SEP-2322. +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +public sealed class InputRequiredResult : Result +{ + /// + /// Initializes a new instance of the class. + /// + public InputRequiredResult() + { + ResultType = "input_required"; + } + + /// + /// Gets or sets the server-initiated requests that the client must fulfill before retrying the original request. + /// + /// + /// + /// The keys are server-assigned identifiers. The client must include a response for each key in the + /// map when retrying the original request. + /// + /// + [JsonPropertyName("inputRequests")] + public IDictionary? InputRequests { get; set; } + + /// + /// Gets or sets opaque state to be echoed back by the client when retrying the original request. + /// + /// + /// + /// The client must treat this as an opaque blob and must not inspect, parse, modify, or make + /// any assumptions about the contents. If present, the client must include this value in the + /// property when retrying the original request. + /// + /// + /// Servers may encode request state in any format (e.g., plain JSON, base64-encoded JSON, + /// encrypted JWT, serialized binary). If the state contains sensitive data, servers should + /// encrypt it to ensure confidentiality and integrity. + /// + /// + [JsonPropertyName("requestState")] + public string? RequestState { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/InputResponse.cs b/src/ModelContextProtocol.Core/Protocol/InputResponse.cs new file mode 100644 index 000000000..79ac22dc9 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/InputResponse.cs @@ -0,0 +1,128 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents a client's response to a server-initiated as part of an MRTR +/// (Multi Round-Trip Request) flow. +/// +/// +/// +/// An wraps the result of a server-to-client request such as +/// , , or . +/// The type of the inner response corresponds to the of the +/// associated input request. +/// +/// +/// The input response does not carry its own type discriminator in JSON. The type is determined by +/// the corresponding key in the map. +/// +/// +[Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] +[JsonConverter(typeof(Converter))] +public sealed class InputResponse +{ + /// + /// Gets or sets the raw JSON element representing the response. + /// + /// + /// Use with the JsonTypeInfo<T> matching the + /// associated — for elicitation, sampling, or roots see + /// , , and + /// . + /// + [JsonIgnore] + public JsonElement RawValue { get; set; } + + /// + /// Deserializes the raw value to the specified result type. + /// + /// The type to deserialize to (e.g., , ). + /// The JSON type information for . + /// The deserialized result, or if deserialization fails. + public T? Deserialize(System.Text.Json.Serialization.Metadata.JsonTypeInfo typeInfo) => + JsonSerializer.Deserialize(RawValue, typeInfo); + + /// + /// Gets the for , suitable for use with + /// when the corresponding is + /// . + /// + public static JsonTypeInfo ElicitResultJsonTypeInfo => McpJsonUtilities.JsonContext.Default.ElicitResult; + + /// + /// Gets the for , suitable for use with + /// when the corresponding is + /// . + /// + public static JsonTypeInfo CreateMessageResultJsonTypeInfo => McpJsonUtilities.JsonContext.Default.CreateMessageResult; + + /// + /// Gets the for , suitable for use with + /// when the corresponding is + /// . + /// + public static JsonTypeInfo ListRootsResultJsonTypeInfo => McpJsonUtilities.JsonContext.Default.ListRootsResult; + + /// + /// Creates an from a . + /// + /// The sampling result. + /// A new instance. + public static InputResponse FromSamplingResult(CreateMessageResult result) + { + Throw.IfNull(result); + return new() + { + RawValue = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CreateMessageResult), + }; + } + + /// + /// Creates an from an . + /// + /// The elicitation result. + /// A new instance. + public static InputResponse FromElicitResult(ElicitResult result) + { + Throw.IfNull(result); + return new() + { + RawValue = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.ElicitResult), + }; + } + + /// + /// Creates an from a . + /// + /// The roots list result. + /// A new instance. + public static InputResponse FromRootsResult(ListRootsResult result) + { + Throw.IfNull(result); + return new() + { + RawValue = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.ListRootsResult), + }; + } + + /// Provides JSON serialization support for . + public sealed class Converter : JsonConverter + { + /// + public override InputResponse? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return new InputResponse { RawValue = element }; + } + + /// + public override void Write(Utf8JsonWriter writer, InputResponse value, JsonSerializerOptions options) + { + value.RawValue.WriteTo(writer); + } + } +} diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs index 2fa9839f0..e5c0f3931 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -74,4 +74,15 @@ public sealed class JsonRpcMessageContext /// /// public IDictionary? Items { get; set; } + + /// + /// Gets or sets the protocol version from the transport-level header (e.g. Mcp-Protocol-Version) + /// that accompanied this JSON-RPC message. + /// + /// + /// In stateless Streamable HTTP mode, the protocol version cannot be negotiated via the initialize + /// handshake because each request creates a new server instance. This property allows the transport layer + /// to flow the protocol version header so the server can determine client capabilities. + /// + public string? ProtocolVersion { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/RequestParams.cs b/src/ModelContextProtocol.Core/Protocol/RequestParams.cs index 0a0586a71..004f1711f 100644 --- a/src/ModelContextProtocol.Core/Protocol/RequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/RequestParams.cs @@ -1,3 +1,4 @@ +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Nodes; using System.Text.Json.Serialization; @@ -25,6 +26,52 @@ private protected RequestParams() [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + /// + /// Gets or sets the responses to server-initiated input requests from a previous . + /// + /// + /// + /// This property is populated when retrying a request after receiving an . + /// Each key corresponds to a key from the map, and + /// the value is the client's response to that input request. + /// + /// + [Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] + [JsonIgnore] + public IDictionary? InputResponses + { + get => InputResponsesCore; + set => InputResponsesCore = value; + } + + // See ExperimentalInternalPropertyTests.cs before modifying this property. + [JsonInclude] + [JsonPropertyName("inputResponses")] + internal IDictionary? InputResponsesCore { get; set; } + + /// + /// Gets or sets opaque request state echoed back from a previous . + /// + /// + /// + /// This property is populated when retrying a request after receiving an + /// that included a value. The client must echo back the + /// exact value without modification. + /// + /// + [Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] + [JsonIgnore] + public string? RequestState + { + get => RequestStateCore; + set => RequestStateCore = value; + } + + // See ExperimentalInternalPropertyTests.cs before modifying this property. + [JsonInclude] + [JsonPropertyName("requestState")] + internal string? RequestStateCore { get; set; } + /// /// Gets the opaque token that will be attached to any subsequent progress notifications. /// diff --git a/src/ModelContextProtocol.Core/Protocol/Result.cs b/src/ModelContextProtocol.Core/Protocol/Result.cs index 58b076ddb..6e43249a1 100644 --- a/src/ModelContextProtocol.Core/Protocol/Result.cs +++ b/src/ModelContextProtocol.Core/Protocol/Result.cs @@ -21,4 +21,18 @@ private protected Result() /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the type of the result, which allows the client to determine how to parse the result object. + /// + /// + /// + /// When absent or set to "complete", the result is a normal completed response. + /// When set to "input_required", the result is an indicating + /// that additional input is needed before the request can be completed. + /// + /// + /// Defaults to , which is equivalent to "complete". + [JsonPropertyName("resultType")] + public string? ResultType { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index 957f58a51..6599fb0b4 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol; -using System.Diagnostics; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; @@ -15,6 +14,8 @@ internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport public override IServiceProvider? Services => server.Services; public override LoggingLevel? LoggingLevel => server.LoggingLevel; + public override bool IsMrtrSupported => server.IsMrtrSupported; + public override ValueTask DisposeAsync() => server.DisposeAsync(); public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index b8b41bdc3..444365361 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -64,6 +64,25 @@ protected McpServer() /// Gets the last logging level set by the client, or if it's never been set. public abstract LoggingLevel? LoggingLevel { get; } + /// + /// Gets a value indicating whether the connected client supports Multi Round-Trip Requests (MRTR). + /// + /// + /// + /// When this property returns , tool handlers can throw + /// to return an + /// with and/or + /// to the client. + /// + /// + /// When this property returns , tool handlers should provide a fallback + /// experience (for example, returning a text message explaining that the client does not support + /// the required feature) instead of throwing . + /// + /// + [Experimental(Experimentals.Mrtr_DiagnosticId, UrlFormat = Experimentals.Mrtr_Url)] + public virtual bool IsMrtrSupported => false; + /// /// Runs the server, listening for and handling client requests. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index 203856814..8fd118b6b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -4,6 +4,7 @@ using ModelContextProtocol.Protocol; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; @@ -231,7 +232,8 @@ private void ConfigureInitialize(McpServerOptions options) // Otherwise, try to use whatever the client requested as long as it's supported. // If it's not supported, fall back to the latest supported version. string? protocolVersion = options.ProtocolVersion; - protocolVersion ??= request?.ProtocolVersion is string clientProtocolVersion && McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? + protocolVersion ??= request?.ProtocolVersion is string clientProtocolVersion && + McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? clientProtocolVersion : McpSessionHandler.LatestProtocolVersion; @@ -725,7 +727,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) McpErrorCode.InvalidParams); } - // Task augmentation requested - return CreateTaskResult + // Task augmentation requested with immediate creation return await ExecuteToolAsTaskAsync(tool, request, taskMetadata, taskStore, sendNotifications, cancellationToken).ConfigureAwait(false); } @@ -774,9 +776,18 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) } catch (Exception e) { - ToolCallError(request.Params?.Name ?? string.Empty, e); + // Skip logging for OperationCanceledException when the cancellation token + // is signaled — tool handler cancellation is an expected lifecycle event + // (client request cancellation, session shutdown, MRTR teardown), not a + // tool error. + // Skip logging for InputRequiredException — it's normal MRTR control flow, + // not an error (tools throw it to signal an InputRequiredResult). + if (!(e is OperationCanceledException && cancellationToken.IsCancellationRequested) && e is not InputRequiredException) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); + } - if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException) + if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException || e is InputRequiredException) { throw; } @@ -990,7 +1001,7 @@ private ValueTask InvokeHandlerAsync( { return _servicesScopePerRequest ? InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest, args), cancellationToken); + handler(new(CreateDestinationBoundServer(jsonRpcRequest), jsonRpcRequest, args), cancellationToken); async ValueTask InvokeScopedAsync( McpRequestHandler handler, @@ -1002,7 +1013,7 @@ async ValueTask InvokeScopedAsync( try { return await handler( - new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest, args) + new RequestContext(CreateDestinationBoundServer(jsonRpcRequest), jsonRpcRequest, args) { Services = scope?.ServiceProvider ?? Services, }, @@ -1018,6 +1029,9 @@ async ValueTask InvokeScopedAsync( } } + private DestinationBoundMcpServer CreateDestinationBoundServer(JsonRpcRequest jsonRpcRequest) => + new(this, jsonRpcRequest.Context?.RelatedTransport); + private void SetHandler( string method, McpRequestHandler handler, @@ -1028,6 +1042,13 @@ private void SetHandler( (request, jsonRpcRequest, cancellationToken) => InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), requestTypeInfo, responseTypeInfo); + + if (method == RequestMethods.ToolsCall) + { + var originalHandler = _requestHandlers[method]; + _requestHandlers[method] = (request, cancellationToken) => + InvokeWithInputRequiredResultHandlingAsync(originalHandler, request, cancellationToken); + } } private static McpRequestHandler BuildFilterPipeline( @@ -1106,6 +1127,200 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => _ => Protocol.LoggingLevel.Emergency, }; + /// + /// Checks whether the negotiated protocol version enables MRTR per SEP-2322 (DRAFT-2026-v1). + /// + internal bool ClientSupportsMrtr() => + _negotiatedProtocolVersion == McpSessionHandler.DraftProtocolVersion; + + /// + /// Returns when the session is stateful — the same server instance handles + /// subsequent requests on the same session. The legacy backcompat resolver in + /// needs a stateful session so it can send + /// elicitation/create / sampling/createMessage / roots/list to the client and + /// retry the handler with the responses. + /// + internal bool IsStatefulSession() => + _sessionTransport is not StreamableHttpServerTransport { Stateless: true }; + + /// + public override bool IsMrtrSupported => ClientSupportsMrtr() || IsStatefulSession(); + + /// + /// Invokes a handler and catches to convert it to an + /// JSON response. When MRTR is negotiated or the server is stateless, + /// the result is serialized directly. Otherwise, input requests are resolved via standard JSON-RPC + /// calls (elicitation, sampling, roots) and the handler is retried with the responses — allowing + /// MRTR-native tools to work transparently with clients that don't support MRTR. + /// + private async Task InvokeWithInputRequiredResultHandlingAsync( + Func> handler, + JsonRpcRequest request, + CancellationToken cancellationToken) + { + const int MaxRetries = 10; + + // In stateless mode, pick up the negotiated draft protocol version from the + // transport-provided request context because there is no long-lived initialize handshake state. + if (_negotiatedProtocolVersion is null && + request.Context?.ProtocolVersion is { } headerProtocolVersion) + { + _negotiatedProtocolVersion = headerProtocolVersion; + } + + for (int retry = 0; ; retry++) + { + try + { + return await handler(request, cancellationToken).ConfigureAwait(false); + } + catch (InputRequiredException ex) + { + // If the client natively supports MRTR, serialize and return directly — + // the client will drive the retry loop. + if (ClientSupportsMrtr()) + { + return SerializeInputRequiredResult(ex.Result); + } + + // In stateless mode without MRTR, the server can't resolve input requests via + // JSON-RPC (no persistent session for server-to-client requests), and the client + // won't recognize the InputRequiredResult. This is the one unsupported configuration. + // TODO(stateless-draft): When DRAFT-2026-v1 becomes stateless-only, the IsStatefulSession() gate collapses — the stateful path will only matter for legacy clients on the current protocol. + if (!IsStatefulSession()) + { + throw new McpException( + "A tool handler returned an incomplete result, but the server is stateless and the client does not support MRTR. " + + "MRTR-native tools require either an MRTR-capable client or a stateful server for backward-compatible resolution.", ex); + } + + // Backcompat: resolve input requests via standard JSON-RPC calls and retry the handler. + if (ex.Result.InputRequests is not { Count: > 0 } inputRequests) + { + throw new McpException( + "A tool handler returned an incomplete result without input requests, and the client does not support MRTR.", ex); + } + + if (retry >= MaxRetries) + { + throw new McpException( + $"MRTR-native tool exceeded {MaxRetries} retry rounds without completing.", ex); + } + + // Resolve each input request by sending the corresponding JSON-RPC call to the client. + // Route the outgoing requests via the same DestinationBoundMcpServer used for normal tool + // handlers, so they go through the POST's response stream (RelatedTransport) rather than + // the session-level transport. Without this, the messages can race with the client's GET + // stream startup and be silently dropped by StreamableHttpServerTransport.SendMessageAsync + // when no GET request has arrived yet. + var destinationServer = CreateDestinationBoundServer(request); + var inputResponses = await ResolveInputRequestsAsync(destinationServer, inputRequests, cancellationToken).ConfigureAwait(false); + + // Reconstruct request params with inputResponses and requestState for the retry. + var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject(); + paramsObj["inputResponses"] = JsonSerializer.SerializeToNode( + (IDictionary)inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse); + + if (ex.Result.RequestState is { } requestState) + { + paramsObj["requestState"] = requestState; + } + + request = new JsonRpcRequest + { + Id = request.Id, + Method = request.Method, + Params = paramsObj, + Context = request.Context, + }; + } + } + } + + /// + /// Resolves a batch of MRTR input requests concurrently by dispatching each as a standard + /// JSON-RPC request to the client. The requests are routed via + /// so they go out through the POST's response stream (matching the behavior of tool-initiated + /// server-to-client requests like server.SampleAsync) and avoid racing with the client's + /// GET stream startup. On the first failure all remaining handlers are cancelled so user-facing + /// flows (sampling/elicitation prompts) don't keep running once the caller has given up, and + /// exceptions from late-completing tasks are observed before the original exception is rethrown. + /// + private static async Task> ResolveInputRequestsAsync( + McpServer destinationServer, + IDictionary inputRequests, + CancellationToken cancellationToken) + { + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + var keyed = new (string Key, Task Task)[inputRequests.Count]; + int i = 0; + foreach (var kvp in inputRequests) + { + keyed[i++] = (kvp.Key, ResolveInputRequestAsync(destinationServer, kvp.Value, linkedCts.Token)); + } + + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + linkedCts.Cancel(); + try + { + await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false); + } + catch + { + // Observed; the original exception is the one we want to surface. + } + throw; + } + + var responses = new Dictionary(keyed.Length); + foreach (var (key, task) in keyed) + { + responses[key] = task.Result; + } + return responses; + } + + /// + /// Resolves a single MRTR by dispatching it as a standard JSON-RPC + /// request to the client via . This is the server-side mirror + /// of the client's input resolution logic, used for backward compatibility when the client doesn't + /// support MRTR. + /// + private static async Task ResolveInputRequestAsync(McpServer destinationServer, InputRequest inputRequest, CancellationToken cancellationToken) + { + switch (inputRequest.Method) + { + case RequestMethods.ElicitationCreate: + var elicitParams = inputRequest.ElicitationParams + ?? throw new McpException("Failed to deserialize elicitation parameters from MRTR input request."); + var elicitResult = await destinationServer.ElicitAsync(elicitParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromElicitResult(elicitResult); + + case RequestMethods.SamplingCreateMessage: + var samplingParams = inputRequest.SamplingParams + ?? throw new McpException("Failed to deserialize sampling parameters from MRTR input request."); + var samplingResult = await destinationServer.SampleAsync(samplingParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromSamplingResult(samplingResult); + + case RequestMethods.RootsList: + var rootsParams = inputRequest.RootsParams ?? new ListRootsRequestParams(); + var rootsResult = await destinationServer.RequestRootsAsync(rootsParams, cancellationToken).ConfigureAwait(false); + return InputResponse.FromRootsResult(rootsResult); + + default: + throw new McpException($"Unsupported input request method: '{inputRequest.Method}'."); + } + } + + private static JsonNode? SerializeInputRequiredResult(InputRequiredResult inputRequiredResult) => + JsonSerializer.SerializeToNode(inputRequiredResult, McpJsonUtilities.JsonContext.Default.InputRequiredResult); + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] private partial void ToolCallError(string toolName, Exception exception); diff --git a/tests/Common/Utils/NodeHelpers.cs b/tests/Common/Utils/NodeHelpers.cs index a30dd3fc3..ef1686abb 100644 --- a/tests/Common/Utils/NodeHelpers.cs +++ b/tests/Common/Utils/NodeHelpers.cs @@ -205,6 +205,44 @@ public static bool HasSep2243Scenarios() } } + /// + /// Checks whether the SEP-2322 (Multi Round-Trip Requests / IncompleteResult) + /// conformance scenarios are available by reading the conformance package version + /// from the repo's package.json. MRTR scenarios require a conformance package version + /// that includes SEP-2322 support (see + /// https://github.com/modelcontextprotocol/conformance/pull/188). + /// + public static bool HasMrtrScenarios() + { + try + { + var repoRoot = FindRepoRoot(); + var packageJsonPath = Path.Combine(repoRoot, "package.json"); + if (!File.Exists(packageJsonPath)) + { + return false; + } + + var json = System.Text.Json.JsonDocument.Parse(File.ReadAllText(packageJsonPath)); + if (json.RootElement.TryGetProperty("dependencies", out var deps) && + deps.TryGetProperty("@modelcontextprotocol/conformance", out var versionElement)) + { + var versionStr = versionElement.GetString(); + if (versionStr is not null && Version.TryParse(versionStr, out var version)) + { + // SEP-2322 scenarios are expected in conformance package >= 0.2.0 + return version >= new Version(0, 2, 0); + } + } + + return false; + } + catch + { + return false; + } + } + private static ProcessStartInfo NpmStartInfo(string arguments, string workingDirectory) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) diff --git a/tests/Common/Utils/ServerMessageTracker.cs b/tests/Common/Utils/ServerMessageTracker.cs new file mode 100644 index 000000000..66a80c681 --- /dev/null +++ b/tests/Common/Utils/ServerMessageTracker.cs @@ -0,0 +1,95 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Text.Json.Nodes; +using Xunit; + +namespace ModelContextProtocol.Tests.Utils; + +/// +/// Tracks MRTR protocol mode via incoming and outgoing message filters. +/// Used by MRTR tests to verify the correct protocol mode (MRTR vs legacy) was used. +/// +internal sealed class ServerMessageTracker +{ + private static readonly HashSet LegacyMrtrMethods = + [ + RequestMethods.ElicitationCreate, + RequestMethods.SamplingCreateMessage, + RequestMethods.RootsList, + ]; + + private readonly ConcurrentBag _legacyRequestMethods = []; + private int _mrtrRetryCount; + private int _incompleteResultCount; + + /// + /// Adds incoming and outgoing message filters to track MRTR protocol usage. + /// Call this in services.Configure<McpServerOptions> or AddMcpServer callbacks. + /// + public void AddFilters(McpMessageFilters messageFilters) + { + // Track outgoing legacy JSON-RPC requests and InputRequiredResult responses. + messageFilters.OutgoingFilters.Add(next => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && LegacyMrtrMethods.Contains(request.Method)) + { + _legacyRequestMethods.Add(request.Method); + } + else if (context.JsonRpcMessage is JsonRpcResponse response && + response.Result is JsonObject resultObj && + resultObj.TryGetPropertyValue("resultType", out var resultTypeNode) && + resultTypeNode?.GetValue() == "input_required") + { + Interlocked.Increment(ref _incompleteResultCount); + } + + await next(context, cancellationToken); + }); + + // Track incoming MRTR retries (requests with inputResponses or requestState in params). + messageFilters.IncomingFilters.Add(next => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && + request.Params is JsonObject paramsObj && + (paramsObj.ContainsKey("inputResponses") || paramsObj.ContainsKey("requestState"))) + { + Interlocked.Increment(ref _mrtrRetryCount); + } + + await next(context, cancellationToken); + }); + } + + /// + /// Asserts that MRTR was used: at least one InputRequiredResult response was sent + /// and no legacy JSON-RPC requests (elicitation/create, sampling/createMessage, roots/list) were sent. + /// + public void AssertMrtrUsed() + { + Assert.True(_incompleteResultCount > 0, + "Expected at least one InputRequiredResult response (MRTR mode), but none were detected."); + Assert.Empty(_legacyRequestMethods); + } + + /// + /// Asserts that MRTR was used at least once (at least one InputRequiredResult response was sent), + /// independent of whether the session also issued any legacy server-to-client requests. + /// + public void AssertMrtrUsedAtLeastOnce() + { + Assert.True(_incompleteResultCount > 0, + "Expected at least one InputRequiredResult response (MRTR mode), but none were detected."); + } + + /// + /// Asserts that legacy mode was used: at least one legacy JSON-RPC request was sent + /// and no MRTR retries or InputRequiredResult responses were detected. + /// + public void AssertMrtrNotUsed() + { + Assert.NotEmpty(_legacyRequestMethods); + Assert.Equal(0, _mrtrRetryCount); + Assert.Equal(0, _incompleteResultCount); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs index 5552b5395..ac19953bf 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStatelessTests.cs @@ -1,7 +1,45 @@ -namespace ModelContextProtocol.AspNetCore.Tests; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore.Tests; public class MapMcpStatelessTests(ITestOutputHelper outputHelper) : MapMcpStreamableHttpTests(outputHelper) { protected override bool UseStreamableHttp => true; protected override bool Stateless => true; + + [Fact] + public async Task EnablePollingAsync_ThrowsInvalidOperationException_InStatelessMode() + { + InvalidOperationException? capturedException = null; + var pollingTool = McpServerTool.Create(async (RequestContext context) => + { + try + { + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + } + catch (InvalidOperationException ex) + { + capturedException = ex; + } + + return "Complete"; + }, options: new() { Name = "polling_tool" }); + + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools([pollingTool]); + + await using var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync(); + + await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(capturedException); + Assert.Contains("stateless", capturedException.Message, StringComparison.OrdinalIgnoreCase); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index 4f2d5aaeb..b95ea67a6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -347,9 +347,9 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectAsync(clientOptions: new() + await using var mcpClient = await ConnectAsync(configureClient: options => { - ProtocolVersion = "2025-06-18", + options.ProtocolVersion = "2025-06-18"; }); Assert.Equal("2025-06-18", mcpClient.NegotiatedProtocolVersion); @@ -457,41 +457,6 @@ public async Task CanResumeSessionWithMapMcpAndRunSessionHandler() Assert.Equal(1, runSessionCount); } - [Fact] - public async Task EnablePollingAsync_ThrowsInvalidOperationException_InStatelessMode() - { - Assert.SkipUnless(Stateless, "This test only applies to stateless mode."); - - InvalidOperationException? capturedException = null; - var pollingTool = McpServerTool.Create(async (RequestContext context) => - { - try - { - await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); - } - catch (InvalidOperationException ex) - { - capturedException = ex; - } - - return "Complete"; - }, options: new() { Name = "polling_tool" }); - - Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools([pollingTool]); - - await using var app = Builder.Build(); - app.MapMcp(); - - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var mcpClient = await ConnectAsync(); - - await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(capturedException); - Assert.Contains("stateless", capturedException.Message, StringComparison.OrdinalIgnoreCase); - } - [Fact] public async Task EnablePollingAsync_ThrowsInvalidOperationException_WhenNoEventStreamStoreConfigured() { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs new file mode 100644 index 000000000..a6350e5cc --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs @@ -0,0 +1,746 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public abstract partial class MapMcpTests +{ + private ServerMessageTracker ConfigureServer(params Delegate[] tools) + { + var messageTracker = new ServerMessageTracker(); + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation { Name = "MrtrTestServer", Version = "1" }; + // Do not pin a protocol version — let it be negotiated based on what the client requests. + // DRAFT-2026-v1 is in SupportedProtocolVersions, so an opt-in client gets it; others get + // the latest non-draft. + messageTracker.AddFilters(options.Filters.Message); + }) + .WithHttpTransport(ConfigureStateless) + .WithTools(tools.Select(t => McpServerTool.Create(t))); + return messageTracker; + } + + private Task ConnectExperimentalAsync() => + ConnectAsync(configureClient: options => + { + ConfigureMrtrHandlers(options); + options.ProtocolVersion = "DRAFT-2026-v1"; + }); + + private Task ConnectDefaultAsync() => + ConnectAsync(configureClient: ConfigureMrtrHandlers); + + /// Configures elicitation, sampling, and roots handlers on client options. + private static void ConfigureMrtrHandlers(McpClientOptions options) + { + options.Handlers.ElicitationHandler = (request, ct) => + { + var message = request?.Message ?? ""; + var answer = message.Contains("name", StringComparison.OrdinalIgnoreCase) ? "Alice" + : message.Contains("greet", StringComparison.OrdinalIgnoreCase) ? "Hello" + : "yes"; + + return new ValueTask(new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["answer"] = JsonDocument.Parse($"\"{answer}\"").RootElement.Clone() + } + }); + }; + options.Handlers.SamplingHandler = (request, progress, ct) => + { + var prompt = request?.Messages?.LastOrDefault()?.Content + .OfType().FirstOrDefault()?.Text ?? ""; + return new ValueTask(new CreateMessageResult + { + Content = [new TextContentBlock { Text = $"LLM:{prompt}" }], + Model = "test-model" + }); + }; + options.Handlers.RootsHandler = (request, ct) => + { + return new ValueTask(new ListRootsResult + { + Roots = [ + new Root { Uri = "file:///project", Name = "Project" }, + new Root { Uri = "file:///data", Name = "Data" } + ] + }); + }; + } + + // ===================================================================== + // MRTR tests: experimental (native), backcompat (legacy JSON-RPC), and edge cases. + // Each test creates its own server with DRAFT-2026-v1 enabled. + // ===================================================================== + + [McpServerTool(Name = "mrtr-mixed")] + private static async Task MrtrMixed(McpServer server, RequestContext context, CancellationToken ct) + { + var state = context.Params!.RequestState; + var responses = context.Params!.InputResponses; + + // Round 3 entry: confirmation from round 2 available. Transition to await API. + if (state == "round-2" && responses?.TryGetValue("confirm", out var confirmResponse) == true) + { + var confirmation = confirmResponse.Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action ?? "unknown"; + + // Await API: sequential sampling then elicitation + var sampleResult = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Write greeting" }] }], + MaxTokens = 100 + }, ct); + var greeting = sampleResult.Content.OfType().FirstOrDefault()?.Text ?? ""; + + var signoffResult = await server.ElicitAsync(new ElicitRequestParams + { + Message = "Sign off as?", + RequestedSchema = new() + }, ct); + var signoff = signoffResult.Action; + + return $"{confirmation}|{greeting}|{signoff}"; + } + + // Round 2 entry: parallel results from round 1 available. + if (state == "round-1" && responses is not null) + { + var name = responses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + var weather = responses["weather"].Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content + .OfType().FirstOrDefault()?.Text ?? ""; + var root = responses["roots"].Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots?.FirstOrDefault()?.Name ?? ""; + + // Exception API: single elicitation with requestState + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"Confirm {name} in {weather} near {root}?", + RequestedSchema = new() + }) + }, + requestState: "round-2"); + } + + // Round 1: Exception API with 3 PARALLEL input requests + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new() + }), + ["weather"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Describe the weather" }] }], + MaxTokens = 100 + }), + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "round-1"); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Mrtr_MixedExceptionAndAwaitStyle(bool experimentalClient) + { + // The server always supports DRAFT-2026-v1 (it's in SupportedProtocolVersions). The + // client opts in by pinning ProtocolVersion = "DRAFT-2026-v1"; otherwise it negotiates + // the latest non-draft version and the server falls back to the exception path with + // legacy JSON-RPC resolution. + var messageTracker = ConfigureServer(MrtrMixed); + + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + + // The await-style portion of this tool calls server.SampleAsync/ElicitAsync on round 3. + // In stateless mode, those calls succeed only when the request is still open on the same + // SSE stream — which it is — so the tool runs end-to-end as long as the input requests + // themselves can be resolved (MRTR client) or replayed via legacy JSON-RPC (stateful + legacy). + if (Stateless && !experimentalClient) + { + // Stateless + legacy client: InputRequiredException cannot be resolved (no MRTR wire + // and no persistent server instance for the backcompat retry loop). The server returns + // a JSON-RPC error. + await using var client = await ConnectAsync(configureClient: configureClient); + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-mixed", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + Assert.Contains("stateless", ex.Message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("MRTR", ex.Message); + return; + } + + if (Stateless && experimentalClient) + { + // Stateless + MRTR client: the await-style portion (server.SampleAsync on round 3) + // requires handler suspension across requests, which only works in stateful mode. + // Skip this combination — the await API is documented as stateful-only. + Assert.SkipWhen(true, "Await-style API requires handler suspension (stateful only)."); + return; + } + + // Stateful path — both client modes complete all 3 rounds. + await using var statefulClient = await ConnectAsync(configureClient: configureClient); + + Assert.Equal(experimentalClient ? "DRAFT-2026-v1" : "2025-11-25", + statefulClient.NegotiatedProtocolVersion); + + var result = await statefulClient.CallToolAsync("mrtr-mixed", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.True(result.IsError is not true); + var parts = text.Split('|'); + Assert.Equal(3, parts.Length); + Assert.Equal("accept", parts[0]); + Assert.StartsWith("LLM:", parts[1]); + Assert.Equal("accept", parts[2]); + + if (experimentalClient) + { + // Rounds 1-2 use wire-format MRTR (InputRequiredResult), but round 3's await calls + // still issue legacy elicitation/create + sampling/createMessage requests, so this + // configuration is mixed-mode. + messageTracker.AssertMrtrUsedAtLeastOnce(); + } + else + { + messageTracker.AssertMrtrNotUsed(); + } + } + + [McpServerTool(Name = "mrtr-parallel-await")] + private static async Task MrtrParallelAwait(McpServer server, CancellationToken ct) + { + var elicitTask = server.ElicitAsync(new ElicitRequestParams + { + Message = "Parallel elicit", + RequestedSchema = new() + }, ct); + + var sampleTask = server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Parallel sample" }] }], + MaxTokens = 100 + }, ct); + + var sampleResult = await sampleTask; + var elicitResult = await elicitTask; + return $"parallel-ok:{elicitResult.Action}:{sampleResult.Content.OfType().First().Text}"; + } + + [Fact] + public async Task Mrtr_ParallelAwaits() + { + // Server-side parallel ElicitAsync + SampleAsync awaits use the legacy server-to-client + // request path on stateful sessions, which works the same under either negotiated revision + // (the spec only removes those request methods from Streamable HTTP under draft, which is + // stateless-only territory). Stateless servers can't issue server-to-client requests at all. + Assert.SkipWhen(Stateless, "Server-side awaits require stateful server-to-client requests."); + + ConfigureServer(MrtrParallelAwait); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var client = await ConnectAsync(configureClient: ConfigureMrtrHandlers); + + var result = await client.CallToolAsync("mrtr-parallel-await", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.StartsWith("parallel-ok:", text); + Assert.True(result.IsError is not true); + } + + [McpServerTool(Name = "mrtr-elicit")] + private static string MrtrElicit(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_input", out var response)) + { + return $"elicit-ok:{response.Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_input"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm", + RequestedSchema = new() + }) + }, + requestState: "elicit-state"); + } + + [Fact] + public async Task Mrtr_Roots_CompletesViaMrtr() + { + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-roots")] (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("roots", out var response)) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots; + return $"roots-ok:{string.Join(",", roots?.Select(r => r.Uri) ?? [])}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "roots-state"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectExperimentalAsync(); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-roots", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("roots-ok:file:///project,file:///data", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrUsed(); + } + + [McpServerTool(Name = "mrtr-multi")] + private static string MrtrMulti(RequestContext context) + { + var requestState = context.Params!.RequestState; + var inputResponses = context.Params!.InputResponses; + + if (requestState == "round-2" && inputResponses is not null) + { + var greeting = inputResponses["greeting"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action; + return $"multi-done:greeting={greeting}"; + } + + if (requestState == "round-1" && inputResponses is not null) + { + var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Content?.FirstOrDefault().Value; + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["greeting"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = $"How should I greet {name}?", + RequestedSchema = new() + }) + }, + requestState: "round-2"); + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new() + }) + }, + requestState: "round-1"); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Mrtr_MultiRoundTrip_Completes(bool experimentalClient) + { + var messageTracker = ConfigureServer(MrtrMulti); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + // Configure client — experimental or default based on parameter. + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + await using var client = await ConnectAsync(configureClient: configureClient); + + if (!experimentalClient && Stateless) + { + // Stateless without MRTR: InputRequiredException can't be resolved + // (no MRTR negotiated and no stateful backcompat path). + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-multi", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + return; + } + + var result = await client.CallToolAsync("mrtr-multi", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("multi-done:greeting=accept", text); + Assert.True(result.IsError is not true); + + if (experimentalClient) + { + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + messageTracker.AssertMrtrUsed(); + } + else + { + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + messageTracker.AssertMrtrNotUsed(); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Mrtr_IsMrtrSupported(bool experimentalClient) + { + ConfigureServer([McpServerTool(Name = "mrtr-check")] (McpServer server) => server.IsMrtrSupported.ToString()); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + // Configure client — experimental or default based on parameter. + Action configureClient = experimentalClient + ? options => { ConfigureMrtrHandlers(options); options.ProtocolVersion = "DRAFT-2026-v1"; } + : ConfigureMrtrHandlers; + await using var client = await ConnectAsync(configureClient: configureClient); + Assert.Equal(experimentalClient ? "DRAFT-2026-v1" : "2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-check", + cancellationToken: TestContext.Current.CancellationToken); + + // IsMrtrSupported is false only when stateless AND client didn't negotiate MRTR + // (no backcompat path available). All other combos have MRTR or backcompat support. + var expected = Stateless && !experimentalClient ? "False" : "True"; + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal(expected, text); + } + + [McpServerTool(Name = "mrtr-concurrent-three")] + private static string MrtrConcurrentThree(RequestContext context) + { + if (context.Params!.InputResponses is { Count: 3 } responses && + responses.ContainsKey("elicit") && + responses.ContainsKey("sample") && + responses.ContainsKey("roots")) + { + var elicitAction = responses["elicit"].Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action; + var sampleText = responses["sample"].Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)? + .Content.OfType().FirstOrDefault()?.Text; + var rootUris = string.Join(",", + responses["roots"].Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots.Select(r => r.Uri) ?? []); + return $"all-ok:elicit={elicitAction},sample={sampleText},roots={rootUris}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["elicit"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Confirm action", + RequestedSchema = new() + }), + ["sample"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Generate summary" }] + }], + MaxTokens = 50 + }), + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "concurrent-state"); + } + + [Fact] + public async Task Mrtr_ConcurrentThreeInputs_ResolvedSimultaneously() + { + var messageTracker = ConfigureServer(MrtrConcurrentThree); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + var elicitCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var samplingCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var rootsCalled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await using var client = await ConnectAsync(configureClient: options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + options.Handlers.ElicitationHandler = async (request, ct) => + { + elicitCalled.TrySetResult(); + await Task.WhenAll(samplingCalled.Task.WaitAsync(ct), rootsCalled.Task.WaitAsync(ct)); + return new ElicitResult { Action = "accept" }; + }; + options.Handlers.SamplingHandler = async (request, progress, ct) => + { + samplingCalled.TrySetResult(); + await Task.WhenAll(elicitCalled.Task.WaitAsync(ct), rootsCalled.Task.WaitAsync(ct)); + return new CreateMessageResult + { + Content = [new TextContentBlock { Text = "AI-summary" }], + Model = "test-model" + }; + }; + options.Handlers.RootsHandler = async (request, ct) => + { + rootsCalled.TrySetResult(); + await Task.WhenAll(elicitCalled.Task.WaitAsync(ct), samplingCalled.Task.WaitAsync(ct)); + return new ListRootsResult + { + Roots = [new Root { Uri = "file:///workspace", Name = "Workspace" }] + }; + }; + }); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-concurrent-three", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("all-ok:elicit=accept,sample=AI-summary,roots=file:///workspace", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task Mrtr_LoadShedding_RequestStateOnly_CompletesViaMrtr() + { + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-loadshed")] (RequestContext context) => + { + if (context.Params!.RequestState is { } state) + { + return $"resumed:{state}"; + } + + // requestState-only InputRequiredException (no inputRequests) + throw new InputRequiredException(requestState: "deferred-work"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectExperimentalAsync(); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-loadshed", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("resumed:deferred-work", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrUsed(); + } + + [Fact] + public async Task Mrtr_Backcompat_Roots_ResolvedViaLegacyJsonRpc() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-roots-backcompat")] (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("roots", out var response)) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots; + return $"roots-ok:{roots?.FirstOrDefault()?.Name}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "roots-state"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectDefaultAsync(); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-roots-backcompat", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("roots-ok:Project", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrNotUsed(); + } + + [Fact] + public async Task Mrtr_Backcompat_MultipleInputRequests_ResolvedViaLegacyJsonRpc() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + var messageTracker = ConfigureServer( + [McpServerTool(Name = "mrtr-multi-input")] (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("confirm", out var elicitResponse) && + responses.TryGetValue("summarize", out var sampleResponse)) + { + var action = elicitResponse.Deserialize(InputResponse.ElicitResultJsonTypeInfo)?.Action; + var text = sampleResponse.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content.OfType().FirstOrDefault()?.Text; + return $"both:{action}:{text}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm", + RequestedSchema = new() + }), + ["summarize"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Summarize" }] + }], + MaxTokens = 100 + }) + }, + requestState: "multi-input-state"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectDefaultAsync(); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var result = await client.CallToolAsync("mrtr-multi-input", + cancellationToken: TestContext.Current.CancellationToken); + + var text = Assert.IsType(Assert.Single(result.Content)).Text; + Assert.Equal("both:accept:LLM:Summarize", text); + Assert.True(result.IsError is not true); + messageTracker.AssertMrtrNotUsed(); + } + + [Fact] + public async Task Mrtr_Backcompat_AlwaysIncomplete_FailsAfterMaxRetries() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + int elicitCallCount = 0; + + ConfigureServer( + [McpServerTool(Name = "mrtr-always-incomplete")] (RequestContext context) => + { + // Always throw — never complete + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Confirm again", + RequestedSchema = new() + }) + }, + requestState: "infinite"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectAsync(configureClient: options => + { + ConfigureMrtrHandlers(options); + var originalHandler = options.Handlers.ElicitationHandler!; + options.Handlers.ElicitationHandler = (request, ct) => + { + Interlocked.Increment(ref elicitCallCount); + return originalHandler(request, ct); + }; + }); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-always-incomplete", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("exceeded", ex.Message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("10", ex.Message); + Assert.Equal(10, elicitCallCount); + } + + [Fact] + public async Task Mrtr_Backcompat_EmptyInputRequests_FailsWithError() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + ConfigureServer( + [McpServerTool(Name = "mrtr-empty-inputs")] (RequestContext context) => + { + throw new InputRequiredException( + inputRequests: new Dictionary(), + requestState: "empty"); + }); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectDefaultAsync(); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-empty-inputs", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("without input requests", ex.Message, StringComparison.OrdinalIgnoreCase); + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + } + + [Fact] + public async Task Mrtr_Backcompat_ClientHandlerThrows_PropagatesError() + { + Assert.SkipWhen(Stateless, "Backcompat requires stateful server for legacy JSON-RPC."); + + ConfigureServer(MrtrElicit); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + await using var client = await ConnectAsync(configureClient: options => + { + ConfigureMrtrHandlers(options); + options.Handlers.ElicitationHandler = (request, ct) => + { + throw new InvalidOperationException("Client-side elicitation failure"); + }; + }); + Assert.Equal("2025-11-25", client.NegotiatedProtocolVersion); + + // Handler exception propagates through the backcompat JSON-RPC round-trip. + // The original exception message gets wrapped in "Request failed (remote)" during backcompat. + var ex = await Assert.ThrowsAsync(() => + client.CallToolAsync("mrtr-elicit", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + Assert.Equal(McpErrorCode.InternalError, ex.ErrorCode); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 678b27022..b9b8381ca 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -14,7 +14,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public abstract class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +public abstract partial class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { protected abstract bool UseStreamableHttp { get; } protected abstract bool Stateless { get; } @@ -27,9 +27,8 @@ protected virtual void ConfigureStateless(HttpServerTransportOptions options) protected async Task ConnectAsync( string? path = null, HttpClientTransportOptions? transportOptions = null, - McpClientOptions? clientOptions = null) + Action? configureClient = null) { - // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; await using var transport = new HttpClientTransport(transportOptions ?? new HttpClientTransportOptions @@ -38,6 +37,8 @@ protected async Task ConnectAsync( TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); + var clientOptions = new McpClientOptions(); + configureClient?.Invoke(clientOptions); return await McpClient.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); } @@ -156,29 +157,24 @@ public async Task Sampling_DoesNotCloseStreamPrematurely() await app.StartAsync(TestContext.Current.CancellationToken); var sampleCount = 0; - var clientOptions = new McpClientOptions() + await using var mcpClient = await ConnectAsync(configureClient: options => { - Handlers = new() + options.Handlers.SamplingHandler = async (parameters, _, _) => { - SamplingHandler = async (parameters, _, _) => - { - Assert.NotNull(parameters?.Messages); - var message = Assert.Single(parameters.Messages); - Assert.Equal(Role.User, message.Role); - Assert.Equal("Test prompt for sampling", Assert.IsType(Assert.Single(message.Content)).Text); + Assert.NotNull(parameters?.Messages); + var message = Assert.Single(parameters.Messages); + Assert.Equal(Role.User, message.Role); + Assert.Equal("Test prompt for sampling", Assert.IsType(Assert.Single(message.Content)).Text); - sampleCount++; - return new CreateMessageResult - { - Model = "test-model", - Role = Role.Assistant, - Content = [new TextContentBlock { Text = "Sampling response from client" }], - }; - } - } - }; - - await using var mcpClient = await ConnectAsync(clientOptions: clientOptions); + sampleCount++; + return new CreateMessageResult + { + Model = "test-model", + Role = Role.Assistant, + Content = [new TextContentBlock { Text = "Sampling response from client" }], + }; + }; + }); var result = await mcpClient.CallToolAsync("sampling-tool", new Dictionary { @@ -375,7 +371,11 @@ public async Task OutgoingFilter_SeesResponsesAndRequests() }, }; - await using var client = await ConnectAsync(clientOptions: clientOptions); + await using var client = await ConnectAsync(configureClient: opts => + { + opts.Capabilities = clientOptions.Capabilities; + opts.Handlers = clientOptions.Handlers; + }); await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); await client.CallToolAsync("echo_claims_principal", @@ -385,10 +385,12 @@ await client.CallToolAsync("sampling-tool", new Dictionary { ["prompt"] = "Hello" }, cancellationToken: TestContext.Current.CancellationToken); - Assert.Contains("initialize-response", observedMessageTypes); - Assert.Contains("tools-list-response", observedMessageTypes); - Assert.Contains("tool-call-response", observedMessageTypes); - Assert.Contains($"request:{RequestMethods.SamplingCreateMessage}", observedMessageTypes); + // Exact counts catch regressions where the outgoing filter pipeline gets applied more than once + // per outbound message (e.g., SendRequestAsync double-wrapping SendToRelatedTransportAsync). + Assert.Equal(1, observedMessageTypes.Count(m => m == "initialize-response")); + Assert.Equal(1, observedMessageTypes.Count(m => m == "tools-list-response")); + Assert.Equal(2, observedMessageTypes.Count(m => m == "tool-call-response")); // one per CallToolAsync + Assert.Equal(2, observedMessageTypes.Count(m => m == $"request:{RequestMethods.SamplingCreateMessage}")); // sampling-tool makes two SampleAsync calls } [Fact] @@ -496,6 +498,7 @@ public async Task OutgoingFilter_CanSendAdditionalMessages() Assert.Equal("injected", extraMessage); } + private ClaimsPrincipal CreateUser(string name) => new(new ClaimsIdentity( [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)], @@ -566,4 +569,5 @@ public static async Task LongRunningOperation( return $"Operation completed after {durationMs}ms"; } } + } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MrtrProtocolTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MrtrProtocolTests.cs new file mode 100644 index 000000000..76625f654 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MrtrProtocolTests.cs @@ -0,0 +1,349 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Protocol-level tests for Multi Round-Trip Requests (MRTR). +/// These tests send raw JSON-RPC requests via HTTP and verify protocol-level behavior +/// including InputRequiredResult structure, retry with inputResponses, and error handling. +/// +public class MrtrProtocolTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private WebApplication? _app; + + private async Task StartAsync() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = nameof(MrtrProtocolTests), + Version = "1", + }; + options.ProtocolVersion = "DRAFT-2026-v1"; + }).WithTools([ + McpServerTool.Create( + async (string message, McpServer server, CancellationToken ct) => + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = message, + RequestedSchema = new() + }, ct); + + return $"{result.Action}:{result.Content?.FirstOrDefault().Value}"; + }, + new McpServerToolCreateOptions + { + Name = "elicit-tool", + Description = "Elicits from client" + }), + McpServerTool.Create( + static string (McpServer _) => throw new McpProtocolException("Tool validation failed", McpErrorCode.InvalidParams), + new McpServerToolCreateOptions + { + Name = "throwing-tool", + Description = "A tool that throws immediately" + }), + ]).WithHttpTransport(); + + _app = Builder.Build(); + _app.MapMcp(); + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + [Fact] + public async Task ToolThatThrows_ReturnsJsonRpcError_NotIncompleteResult() + { + await StartAsync(); + await InitializeWithMrtrAsync(); + + var response = await PostJsonRpcAsync(CallTool("throwing-tool")); + + // Should be a JSON-RPC error, not an InputRequiredResult + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var sseData = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken)); + var message = JsonSerializer.Deserialize(sseData, McpJsonUtilities.DefaultOptions); + var error = Assert.IsType(message); + Assert.Equal((int)McpErrorCode.InvalidParams, error.Error.Code); + Assert.Contains("Tool validation failed", error.Error.Message); + } + + /// + /// Regression test for a CI hang where the server-side MRTR backcompat resolver routed its + /// outgoing roots/list request through the session-level transport, which silently + /// dropped the message when the client's GET stream had not been established yet. The + /// outgoing request must instead go through the POST's response stream (the request's + /// ) so it + /// reaches the client without depending on the GET stream at all. + /// + /// This test deliberately never opens a GET stream — it only POSTs the initialize, the + /// initialized notification, the tools/call, and the roots/list response. If the + /// server falls back to _transport.SendMessageAsync, the test times out instead of + /// reading the expected roots/list SSE event off the tools/call POST response. + /// + [Fact] + public async Task BackcompatResolver_SendsServerRequestOverPostStream_WithoutGetStream() + { + // Configure a server that does NOT pin DRAFT-2026-v1 so it can negotiate the current + // protocol with a legacy client. The backcompat resolver path only runs when the + // negotiated version is not DRAFT-2026-v1. + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = nameof(MrtrProtocolTests), + Version = "1", + }; + }).WithTools([ + McpServerTool.Create( + static string (RequestContext context) => + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("roots", out var response)) + { + var roots = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots; + return $"roots-ok:{roots?.FirstOrDefault()?.Name}"; + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()) + }, + requestState: "roots-state"); + }, + new McpServerToolCreateOptions + { + Name = "backcompat-roots-tool", + Description = "Throws InputRequiredException so the server's backcompat resolver issues a roots/list", + }), + ]).WithHttpTransport(); + + _app = Builder.Build(); + _app.MapMcp(); + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + + // Initialize with the current (non-draft) protocol so the server's backcompat resolver runs. + var initJson = """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{"roots":{}},"clientInfo":{"name":"BackcompatTestClient","version":"1.0.0"}}} + """; + + string sessionId; + using (var initResponse = await PostJsonRpcAsync(initJson)) + { + var initRpcResponse = await AssertSingleSseResponseAsync(initResponse); + Assert.NotNull(initRpcResponse.Result); + Assert.Equal("2025-11-25", initRpcResponse.Result["protocolVersion"]?.GetValue()); + + sessionId = Assert.Single(initResponse.Headers.GetValues("mcp-session-id")); + } + + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + HttpClient.DefaultRequestHeaders.Remove("MCP-Protocol-Version"); + HttpClient.DefaultRequestHeaders.Add("MCP-Protocol-Version", "2025-11-25"); + + // Send the initialized notification. + using (var initializedResponse = await PostJsonRpcAsync( + """{"jsonrpc":"2.0","method":"notifications/initialized"}""")) + { + Assert.True(initializedResponse.IsSuccessStatusCode); + } + + _lastRequestId = 1; + + // POST the tools/call and start reading the response SSE stream. We deliberately do NOT + // open a GET stream — the server-to-client roots/list must be delivered on this POST's + // response. Use HttpCompletionOption.ResponseHeadersRead so the POST returns as soon as + // the response headers arrive instead of waiting for the SSE stream to close. + var callRequest = new HttpRequestMessage(HttpMethod.Post, (string?)null) + { + Content = JsonContent(CallTool("backcompat-roots-tool")), + }; + callRequest.Content.Headers.Add("Mcp-Method", "tools/call"); + callRequest.Content.Headers.Add("Mcp-Name", "backcompat-roots-tool"); + + using var callResponse = await HttpClient.SendAsync( + callRequest, + HttpCompletionOption.ResponseHeadersRead, + TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, callResponse.StatusCode); + Assert.Equal("text/event-stream", callResponse.Content.Headers.ContentType?.MediaType); + + var sseEvents = ReadSseAsync(callResponse.Content) + .GetAsyncEnumerator(TestContext.Current.CancellationToken); + + try + { + // First SSE event on this POST should be the server-initiated roots/list request. + Assert.True(await sseEvents.MoveNextAsync(), + "Server did not send a roots/list request on the tools/call POST response stream. " + + "If this hangs/times out, the MRTR backcompat resolver is routing the outgoing request " + + "through the session-level transport instead of the POST's RelatedTransport."); + + var rootsRequestNode = JsonNode.Parse(sseEvents.Current) as JsonObject; + Assert.NotNull(rootsRequestNode); + Assert.Equal("roots/list", rootsRequestNode["method"]?.GetValue()); + var rootsRequestId = rootsRequestNode["id"]; + Assert.NotNull(rootsRequestId); + + // POST the roots/list response on a separate connection. The server's pending + // RequestRootsAsync await will complete and the backcompat resolver will retry the tool. + var rootsIdLiteral = rootsRequestId.ToJsonString(); + var rootsResponseJson = + "{\"jsonrpc\":\"2.0\",\"id\":" + rootsIdLiteral + + ",\"result\":{\"roots\":[{\"uri\":\"file:///workspace\",\"name\":\"Workspace\"}]}}"; + using (var rootsResponseHttp = await PostJsonRpcAsync(rootsResponseJson)) + { + Assert.True(rootsResponseHttp.IsSuccessStatusCode); + } + + // Next SSE event on the original POST should be the final tools/call response. + Assert.True(await sseEvents.MoveNextAsync(), "Server did not return the final tools/call response."); + var finalResponse = JsonSerializer.Deserialize(sseEvents.Current, GetJsonTypeInfo()); + Assert.NotNull(finalResponse); + Assert.NotNull(finalResponse.Result); + + var content = finalResponse.Result["content"]?.AsArray(); + Assert.NotNull(content); + var firstContent = Assert.Single(content); + Assert.Equal("roots-ok:Workspace", firstContent?["text"]?.GetValue()); + } + finally + { + await sseEvents.DisposeAsync(); + } + } + + // --- Helpers --- + + private static StringContent JsonContent(string json) => new(json, Encoding.UTF8, "application/json"); + private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + + private static async IAsyncEnumerable ReadSseAsync(HttpContent responseContent) + { + var responseStream = await responseContent.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(responseStream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + Assert.Equal("message", sseItem.EventType); + yield return sseItem.Data; + } + } + + private static async Task AssertSingleSseResponseAsync(HttpResponseMessage response) + { + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("text/event-stream", response.Content.Headers.ContentType?.MediaType); + + var sseItem = Assert.Single(await ReadSseAsync(response.Content).ToListAsync(TestContext.Current.CancellationToken)); + var jsonRpcResponse = JsonSerializer.Deserialize(sseItem, GetJsonTypeInfo()); + + Assert.NotNull(jsonRpcResponse); + return jsonRpcResponse; + } + + private Task PostJsonRpcAsync(string json) + { + var content = JsonContent(json); + + // DRAFT-2026-v1 requires Mcp-Method and (for tools/call) Mcp-Name headers per SEP-2243. + // Parse the body to derive them and attach to this request only. + var bodyNode = JsonNode.Parse(json); + if (bodyNode is JsonObject obj) + { + if (obj["method"]?.GetValue() is { } method) + { + content.Headers.Add("Mcp-Method", method); + + if (obj["params"] is JsonObject paramsObj) + { + string? mcpName = method switch + { + "tools/call" or "prompts/get" => paramsObj["name"]?.GetValue(), + "resources/read" => paramsObj["uri"]?.GetValue(), + _ => null, + }; + if (mcpName is not null) + { + content.Headers.Add("Mcp-Name", mcpName); + } + } + } + } + + return HttpClient.PostAsync("", content, TestContext.Current.CancellationToken); + } + + private long _lastRequestId = 1; + + private string Request(string method, string parameters = "{}") + { + var id = Interlocked.Increment(ref _lastRequestId); + return $$""" + {"jsonrpc":"2.0","id":{{id}},"method":"{{method}}","params":{{parameters}}} + """; + } + + private string CallTool(string toolName, string arguments = "{}") => + Request("tools/call", $$""" + {"name":"{{toolName}}","arguments":{{arguments}}} + """); + + /// + /// Initialize a session requesting the experimental protocol version that enables MRTR. + /// + private async Task InitializeWithMrtrAsync() + { + var initJson = """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"DRAFT-2026-v1","capabilities":{"sampling":{},"elicitation":{},"roots":{}},"clientInfo":{"name":"MrtrTestClient","version":"1.0.0"}}} + """; + + using var response = await PostJsonRpcAsync(initJson); + var rpcResponse = await AssertSingleSseResponseAsync(response); + Assert.NotNull(rpcResponse.Result); + + // Verify the server negotiated to the experimental version + var protocolVersion = rpcResponse.Result["protocolVersion"]?.GetValue(); + Assert.Equal("DRAFT-2026-v1", protocolVersion); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + + // Set the MCP-Protocol-Version header for subsequent requests + HttpClient.DefaultRequestHeaders.Remove("MCP-Protocol-Version"); + HttpClient.DefaultRequestHeaders.Add("MCP-Protocol-Version", "DRAFT-2026-v1"); + + // Reset request ID counter since initialize used ID 1 + _lastRequestId = 1; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs index 98cc5971a..ea4187a95 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs @@ -159,6 +159,34 @@ public async Task RunConformanceTest_HttpCustomHeaderServerValidation() $"Conformance test failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); } + // SEP-2322 (Multi Round-Trip Requests / IncompleteResult) conformance scenarios. + // The csharp-sdk ConformanceServer surfaces the matching tools/prompts via + // ConformanceServer.Tools.IncompleteResultTools and ConformanceServer.Prompts.IncompleteResultPrompts. + // Each scenario uses the conformance harness's RawMcpSession, which negotiates DRAFT-2026-v1 + // so the csharp-sdk emits InputRequiredResult on the wire. These tests skip until the + // upstream conformance package ships with SEP-2322 scenarios + // (https://github.com/modelcontextprotocol/conformance/pull/188). + [Theory] + [InlineData("incomplete-result-basic-elicitation")] + [InlineData("incomplete-result-basic-sampling")] + [InlineData("incomplete-result-basic-list-roots")] + [InlineData("incomplete-result-request-state")] + [InlineData("incomplete-result-multiple-input-requests")] + [InlineData("incomplete-result-multi-round")] + [InlineData("incomplete-result-missing-input-response")] + [InlineData("incomplete-result-non-tool-request")] + public async Task RunMrtrConformanceTest(string scenario) + { + Assert.SkipWhen(!NodeHelpers.IsNodeInstalled(), "Node.js is not installed. Skipping conformance tests."); + Assert.SkipWhen(!NodeHelpers.HasMrtrScenarios(), "SEP-2322 MRTR conformance scenarios not yet available in the published @modelcontextprotocol/conformance package."); + + var result = await RunConformanceTestsAsync( + $"server --url {fixture.ServerUrl} --scenario {scenario}"); + + Assert.True(result.Success, + $"MRTR conformance test '{scenario}' failed.\n\nStdout:\n{result.Output}\n\nStderr:\n{result.Error}"); + } + private async Task<(bool Success, string Output, string Error)> RunConformanceTestsAsync(string arguments) { var startInfo = NodeHelpers.ConformanceTestStartInfo(arguments); diff --git a/tests/ModelContextProtocol.ConformanceServer/Program.cs b/tests/ModelContextProtocol.ConformanceServer/Program.cs index 017ec235f..f30d58a4d 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Program.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Program.cs @@ -31,6 +31,7 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide .WithHttpTransport() .WithDistributedCacheEventStreamStore() .WithTools() + .WithTools() .WithTools([ConformanceTools.CreateJsonSchema202012Tool()]) .WithRequestFilters(filters => filters.AddCallToolFilter(next => async (request, cancellationToken) => { @@ -47,6 +48,7 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide return result; })) .WithPrompts() + .WithPrompts() .WithResources() .WithSubscribeToResourcesHandler(async (ctx, ct) => { diff --git a/tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs b/tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs new file mode 100644 index 000000000..4dfe6dfb0 --- /dev/null +++ b/tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs @@ -0,0 +1,68 @@ +#pragma warning disable MCPEXP001 // MRTR (SEP-2322) is experimental. + +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Text.Json; + +namespace ConformanceServer.Prompts; + +/// +/// Prompt implementing the SEP-2322 D1 conformance scenario (incomplete-result-non-tool-request), +/// proving that prompts/get can return an just like +/// tools/call. +/// +[McpServerPromptType] +public sealed class IncompleteResultPrompts +{ + [McpServerPrompt(Name = "test_incomplete_result_prompt")] + [Description("SEP-2322 D1: prompts/get returns IncompleteResult until user_context is supplied.")] + public static GetPromptResult IncompleteResultPrompt(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_context", out var response)) + { + var elicit = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + var contextValue = TryReadString(elicit?.Content, "context") ?? "(unknown)"; + return new GetPromptResult + { + Description = "Prompt customized with elicited user context.", + Messages = + [ + new PromptMessage + { + Role = Role.User, + Content = new TextContentBlock { Text = $"Please continue using context: {contextValue}" }, + }, + ], + }; + } + + throw new InputRequiredException( + new Dictionary + { + ["user_context"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What context should the prompt use?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["context"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["context"], + }, + }), + }); + } + + private static string? TryReadString(IDictionary? content, string key) + { + if (content is null || !content.TryGetValue(key, out var element)) + { + return null; + } + return element.ValueKind == JsonValueKind.String ? element.GetString() : element.ToString(); + } +} diff --git a/tests/ModelContextProtocol.ConformanceServer/Tools/IncompleteResultTools.cs b/tests/ModelContextProtocol.ConformanceServer/Tools/IncompleteResultTools.cs new file mode 100644 index 000000000..e4f373245 --- /dev/null +++ b/tests/ModelContextProtocol.ConformanceServer/Tools/IncompleteResultTools.cs @@ -0,0 +1,279 @@ +#pragma warning disable MCPEXP001 // MRTR (SEP-2322) is experimental. + +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ConformanceServer.Tools; + +/// +/// Tools implementing the SEP-2322 (MRTR / IncompleteResult) conformance scenarios from +/// incomplete-result.ts in the conformance test suite. All tools use the +/// API so they work both in stateful sessions with +/// MRTR-aware clients and in legacy-resolve mode (the SDK will translate exceptions to the +/// proper wire shape based on negotiated protocol version). +/// +[McpServerToolType] +public sealed class IncompleteResultTools +{ + // ──── A1: Basic Elicitation ───────────────────────────────────────────── + [McpServerTool(Name = "test_tool_with_elicitation")] + [Description("SEP-2322 A1: returns IncompleteResult with elicitation/create keyed 'user_name'.")] + public static CallToolResult ToolWithElicitation(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_name", out var response)) + { + var elicit = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + var name = TryReadString(elicit?.Content, "name") ?? "world"; + return TextResult($"Hello, {name}!"); + } + + throw new InputRequiredException( + new Dictionary + { + ["user_name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + }); + } + + // ──── A2: Basic Sampling ──────────────────────────────────────────────── + [McpServerTool(Name = "test_incomplete_result_sampling")] + [Description("SEP-2322 A2: returns IncompleteResult with sampling/createMessage keyed 'capital_question'.")] + public static CallToolResult ToolWithSampling(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("capital_question", out var response)) + { + var text = response.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo)?.Content?.OfType().FirstOrDefault()?.Text ?? "(no text)"; + return TextResult($"Sampling said: {text}"); + } + + throw new InputRequiredException( + new Dictionary + { + ["capital_question"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "What is the capital of France?" }], + }, + ], + MaxTokens = 100, + }), + }); + } + + // ──── A3: Basic ListRoots ─────────────────────────────────────────────── + [McpServerTool(Name = "test_incomplete_result_list_roots")] + [Description("SEP-2322 A3: returns IncompleteResult with roots/list keyed 'client_roots'.")] + public static CallToolResult ToolWithListRoots(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("client_roots", out var response)) + { + var count = response.Deserialize(InputResponse.ListRootsResultJsonTypeInfo)?.Roots?.Count ?? 0; + return TextResult($"Got {count} root(s) from the client."); + } + + throw new InputRequiredException( + new Dictionary + { + ["client_roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()), + }); + } + + // ──── B1: requestState round-trip ─────────────────────────────────────── + private const string RequestStateToken = "mrtr-conformance-state-v1"; + + [McpServerTool(Name = "test_incomplete_result_request_state")] + [Description("SEP-2322 B1: round-trips a requestState string; R2 echoes 'state-ok' on success.")] + public static CallToolResult ToolWithRequestState(RequestContext context) + { + if (context.Params!.RequestState is { } state) + { + if (state != RequestStateToken) + { + return TextResult("state-mismatch: client echoed an unexpected requestState"); + } + return TextResult("state-ok: server received and validated the echoed requestState"); + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["confirm"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Please confirm", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["ok"] = new ElicitRequestParams.BooleanSchema(), + }, + Required = ["ok"], + }, + }), + }, + requestState: RequestStateToken); + } + + // ──── B2: Multiple input requests in one round ────────────────────────── + [McpServerTool(Name = "test_incomplete_result_multiple_inputs")] + [Description("SEP-2322 B2: returns 3 simultaneous inputRequests (elicit + sampling + roots) plus requestState.")] + public static CallToolResult ToolWithMultipleInputs(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && responses.Count >= 3) + { + return TextResult("multiple-inputs-ok: received elicit + sampling + roots responses"); + } + + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["user_name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + ["greeting"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "Generate a greeting" }], + }, + ], + MaxTokens = 50, + }), + ["client_roots"] = InputRequest.ForRootsList(new ListRootsRequestParams()), + }, + requestState: "multi-input-state"); + } + + // ──── B3: Multi-round (R1 -> incomplete, R2 -> incomplete (new state), R3 -> complete) ───── + [McpServerTool(Name = "test_incomplete_result_multi_round")] + [Description("SEP-2322 B3: three-round flow whose requestState changes between rounds.")] + public static CallToolResult ToolWithMultiRound(RequestContext context) + { + var state = context.Params!.RequestState; + if (state is null) + { + // Round 1: elicit name. + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["step1"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Step 1: What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + }, + requestState: "round-1"); + } + + if (state == "round-1") + { + // Round 2: elicit color (new state). + throw new InputRequiredException( + inputRequests: new Dictionary + { + ["step2"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Step 2: What is your favorite color?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["color"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["color"], + }, + }), + }, + requestState: "round-2"); + } + + // Round 3: complete. + return TextResult("multi-round-ok"); + } + + // ──── C1: Missing/wrong inputResponses key — re-request rather than error ──── + [McpServerTool(Name = "test_incomplete_result_elicitation")] + [Description("SEP-2322 C1: re-requests missing inputResponses key instead of erroring.")] + public static CallToolResult ToolForMissingResponse(RequestContext context) + { + if (context.Params!.InputResponses is { } responses && + responses.TryGetValue("user_name", out var response)) + { + var elicit = response.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + var name = TryReadString(elicit?.Content, "name") ?? "world"; + return TextResult($"Hello, {name}!"); + } + + // Either no inputResponses or wrong key — re-request via a fresh InputRequiredResult + // (per SEP-2322 recommendation in scenario C1). + throw new InputRequiredException( + new Dictionary + { + ["user_name"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary + { + ["name"] = new ElicitRequestParams.StringSchema(), + }, + Required = ["name"], + }, + }), + }); + } + + private static CallToolResult TextResult(string text) => new() + { + Content = [new TextContentBlock { Text = text }], + }; + + private static string? TryReadString(IDictionary? content, string key) + { + if (content is null || !content.TryGetValue(key, out var element)) + { + return null; + } + return element.ValueKind == JsonValueKind.String ? element.GetString() : element.ToString(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index 262efbd40..749ef51eb 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -587,6 +587,14 @@ public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) Assert.Equal(protocolVersion ?? "2025-11-25", client.NegotiatedProtocolVersion); } + [Fact] + public async Task ReturnsNegotiatedProtocolVersion_WithExperimentalProtocol() + { + Server.ServerOptions.ProtocolVersion = "DRAFT-2026-v1"; + await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = "DRAFT-2026-v1" }); + Assert.Equal("DRAFT-2026-v1", client.NegotiatedProtocolVersion); + } + [Fact] public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionInvocation_ClientHandlesSamplingWithIChatClient() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs index 171c6bead..b61022c2b 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs @@ -87,8 +87,14 @@ public async Task AddIncomingMessageFilter_Intercepts_Request_Messages() await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - // The message filter should intercept JsonRpcRequest messages - Assert.Contains("JsonRpcRequest", messageTypes); + // The message filter should intercept JsonRpcRequest messages. + // Use strict counts so a regression that invokes the filter pipeline more than once per + // incoming message (analogous to the SendRequestAsync double-wrap regression on the outgoing + // side) would fail this test instead of slipping through Assert.Contains. + // A single ListToolsAsync drives three server-bound messages: initialize (request), + // notifications/initialized (notification), and tools/list (request). + Assert.Equal(2, messageTypes.Count(m => m == nameof(JsonRpcRequest))); + Assert.Equal(1, messageTypes.Count(m => m == nameof(JsonRpcNotification))); } [Fact] @@ -142,6 +148,13 @@ public async Task AddIncomingMessageFilter_Multiple_Filters_Execute_In_Order() Assert.True(idx1Before < idx2Before); Assert.True(idx2Before < idx2After); Assert.True(idx2After < idx1After); + + // Verify each filter ran exactly once per incoming message (initialize + notifications/initialized + tools/list). + // Strict counts catch regressions where the incoming filter pipeline gets invoked more than once per message. + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter1 before")); + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter2 before")); + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter2 after")); + Assert.Equal(3, logMessages.Count(m => m == "MessageFilter1 after")); } [Fact] @@ -372,15 +385,20 @@ public async Task AddOutgoingMessageFilter_Sees_Responses_Notifications_And_Requ await client.CallToolAsync("sampling-tool", new Dictionary { ["prompt"] = "Hello" }, cancellationToken: TestContext.Current.CancellationToken); + // Exact counts catch regressions where the outgoing filter pipeline gets applied more than once + // per outbound message (e.g., SendRequestAsync double-wrapping SendToRelatedTransportAsync). + Assert.Equal(1, observedMessages.Count(m => m == "initialize")); + Assert.Equal(2, observedMessages.Count(m => m == "progress")); // ProgressTool sends two NotifyProgressAsync calls + Assert.Equal(2, observedMessages.Count(m => m == "response")); // one tool-call response per CallToolAsync + Assert.Equal(1, observedMessages.Count(m => m == $"request:{RequestMethods.SamplingCreateMessage}")); + + // Preserve the original ordering intent: initialize first, then progress, then the final response. int initializeIndex = observedMessages.IndexOf("initialize"); int progressIndex = observedMessages.IndexOf("progress"); int responseIndex = observedMessages.LastIndexOf("response"); - int requestIndex = observedMessages.IndexOf($"request:{RequestMethods.SamplingCreateMessage}"); - Assert.True(initializeIndex >= 0); Assert.True(progressIndex > initializeIndex); Assert.True(responseIndex > progressIndex); - Assert.True(requestIndex >= 0); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Protocol/MrtrSerializationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/MrtrSerializationTests.cs new file mode 100644 index 000000000..e44f6527c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/MrtrSerializationTests.cs @@ -0,0 +1,298 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Protocol; + +public static class MrtrSerializationTests +{ + [Fact] + public static void IncompleteResult_SerializationRoundTrip_PreservesAllProperties() + { + var original = new InputRequiredResult + { + InputRequests = new Dictionary + { + ["input_1"] = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new() + }), + ["input_2"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] }], + MaxTokens = 100 + }) + }, + RequestState = "correlation-123", + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("input_required", deserialized.ResultType); + Assert.Equal("correlation-123", deserialized.RequestState); + Assert.NotNull(deserialized.InputRequests); + Assert.Equal(2, deserialized.InputRequests.Count); + Assert.True(deserialized.InputRequests.ContainsKey("input_1")); + Assert.True(deserialized.InputRequests.ContainsKey("input_2")); + } + + [Fact] + public static void IncompleteResult_HasResultTypeIncomplete() + { + var result = new InputRequiredResult(); + Assert.Equal("input_required", result.ResultType); + } + + [Fact] + public static void IncompleteResult_ResultType_AppearsInJson() + { + var result = new InputRequiredResult + { + RequestState = "abc", + }; + + string json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("input_required", (string?)node["resultType"]); + Assert.Equal("abc", (string?)node["requestState"]); + } + + [Fact] + public static void InputRequest_ForElicitation_SerializesCorrectly() + { + var inputRequest = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "Enter name", + RequestedSchema = new() + }); + + string json = JsonSerializer.Serialize(inputRequest, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("elicitation/create", (string?)node["method"]); + Assert.NotNull(node["params"]); + Assert.Equal("Enter name", (string?)node["params"]!["message"]); + } + + [Fact] + public static void InputRequest_ForSampling_SerializesCorrectly() + { + var inputRequest = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Prompt" }] }], + MaxTokens = 50 + }); + + string json = JsonSerializer.Serialize(inputRequest, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("sampling/createMessage", (string?)node["method"]); + Assert.NotNull(node["params"]); + Assert.Equal(50, (int?)node["params"]!["maxTokens"]); + } + + [Fact] + public static void InputRequest_ForRootsList_SerializesCorrectly() + { + var inputRequest = InputRequest.ForRootsList(new ListRootsRequestParams()); + + string json = JsonSerializer.Serialize(inputRequest, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + Assert.NotNull(node); + Assert.Equal("roots/list", (string?)node["method"]); + } + + [Fact] + public static void InputRequest_Elicitation_RoundTrip() + { + var original = InputRequest.ForElicitation(new ElicitRequestParams + { + Message = "test message", + RequestedSchema = new() + }); + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("elicitation/create", deserialized.Method); + Assert.NotNull(deserialized.ElicitationParams); + Assert.Equal("test message", deserialized.ElicitationParams.Message); + } + + [Fact] + public static void InputRequest_Sampling_RoundTrip() + { + var original = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] }], + MaxTokens = 200 + }); + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("sampling/createMessage", deserialized.Method); + Assert.NotNull(deserialized.SamplingParams); + Assert.Equal(200, deserialized.SamplingParams.MaxTokens); + } + + [Fact] + public static void InputRequest_RootsList_RoundTrip() + { + var original = InputRequest.ForRootsList(new ListRootsRequestParams()); + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("roots/list", deserialized.Method); + Assert.NotNull(deserialized.RootsParams); + } + + [Fact] + public static void InputResponse_FromSamplingResult_RoundTrip() + { + var samplingResult = new CreateMessageResult + { + Content = [new TextContentBlock { Text = "Response text" }], + Model = "test-model" + }; + + var inputResponse = InputResponse.FromSamplingResult(samplingResult); + + // Serialize → deserialize + string json = JsonSerializer.Serialize(inputResponse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + var sampling = deserialized.Deserialize(InputResponse.CreateMessageResultJsonTypeInfo); + Assert.NotNull(sampling); + Assert.Equal("test-model", sampling.Model); + } + + [Fact] + public static void InputResponse_FromElicitResult_RoundTrip() + { + var elicitResult = new ElicitResult + { + Action = "confirm", + Content = new Dictionary + { + ["key"] = JsonDocument.Parse("\"value\"").RootElement.Clone() + } + }; + + var inputResponse = InputResponse.FromElicitResult(elicitResult); + + string json = JsonSerializer.Serialize(inputResponse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + var elicit = deserialized.Deserialize(InputResponse.ElicitResultJsonTypeInfo); + Assert.NotNull(elicit); + Assert.Equal("confirm", elicit.Action); + } + + [Fact] + public static void InputResponse_FromRootsResult_RoundTrip() + { + var rootsResult = new ListRootsResult + { + Roots = [new Root { Uri = "file:///test", Name = "Test" }] + }; + + var inputResponse = InputResponse.FromRootsResult(rootsResult); + + string json = JsonSerializer.Serialize(inputResponse, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + var roots = deserialized.Deserialize(InputResponse.ListRootsResultJsonTypeInfo); + Assert.NotNull(roots); + Assert.Single(roots.Roots); + Assert.Equal("file:///test", roots.Roots[0].Uri); + } + + [Fact] + public static void InputRequestDictionary_SerializationRoundTrip() + { + IDictionary requests = new Dictionary + { + ["a"] = InputRequest.ForElicitation(new ElicitRequestParams { Message = "q1", RequestedSchema = new() }), + ["b"] = InputRequest.ForSampling(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "q2" }] }], + MaxTokens = 50 + }), + }; + + string json = JsonSerializer.Serialize(requests, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize>(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Count); + Assert.Equal("elicitation/create", deserialized["a"].Method); + Assert.Equal("sampling/createMessage", deserialized["b"].Method); + } + + [Fact] + public static void InputResponseDictionary_SerializationRoundTrip() + { + IDictionary responses = new Dictionary + { + ["a"] = InputResponse.FromElicitResult(new ElicitResult { Action = "confirm" }), + ["b"] = InputResponse.FromSamplingResult(new CreateMessageResult + { + Content = [new TextContentBlock { Text = "AI" }], + Model = "m1" + }), + }; + + string json = JsonSerializer.Serialize(responses, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize>(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Count); + } + + [Fact] + public static void Result_ResultType_DefaultsToNull() + { + var result = new CallToolResult + { + Content = [new TextContentBlock { Text = "test" }] + }; + + string json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + // result_type should not appear for normal results + Assert.Null(node?["resultType"]); + } + + [Fact] + public static void RequestParams_InputResponses_NotSerializedByDefault() + { + var callParams = new CallToolRequestParams + { + Name = "test-tool", + }; + + string json = JsonSerializer.Serialize(callParams, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json); + + // inputResponses and requestState should not appear when null + Assert.Null(node?["inputResponses"]); + Assert.Null(node?["requestState"]); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/DraftProtocolBackcompatTests.cs b/tests/ModelContextProtocol.Tests/Server/DraftProtocolBackcompatTests.cs new file mode 100644 index 000000000..2c82679d7 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/DraftProtocolBackcompatTests.cs @@ -0,0 +1,151 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Verifies that the server-to-client request methods (, +/// , +/// ) keep working when the negotiated protocol revision is +/// DRAFT-2026-v1 on a stateful session — for example, stdio. +/// +/// +/// Under DRAFT-2026-v1 the spec removes the corresponding server-to-client request methods, but +/// the SDK only fails fast in stateless mode (where the existing ThrowIf*Unsupported guards already +/// throw "X is not supported in stateless mode" because is +/// ). Stdio is implicitly stateful — one per process — so the +/// legacy elicitation/create / sampling/createMessage / roots/list flow still works. +/// A future PR is expected to force DRAFT-2026-v1 Streamable HTTP servers to stateless mode, at which +/// point those configurations will start throwing through the existing stateless guard. +/// +public sealed class DraftProtocolBackcompatTests : ClientServerTestBase +{ + public DraftProtocolBackcompatTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create(ElicitToolAsync, new() { Name = "elicit-tool" }), + McpServerTool.Create(SampleToolAsync, new() { Name = "sample-tool" }), + McpServerTool.Create(RootsToolAsync, new() { Name = "roots-tool" }), + ]); + } + + [Fact] + public async Task ElicitAsync_OnStatefulDraftSession_ResolvesViaLegacyRequest() + { + StartServer(); + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ClientCapabilities + { + Elicitation = new ElicitationCapability(), + }, + Handlers = new McpClientHandlers + { + ElicitationHandler = (_, _) => new ValueTask(new ElicitResult { Action = "accept" }), + }, + }); + + var result = await client.CallToolAsync("elicit-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("elicit-ok:accept", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task SampleAsync_OnStatefulDraftSession_ResolvesViaLegacyRequest() + { + StartServer(); + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ClientCapabilities + { + Sampling = new SamplingCapability(), + }, + Handlers = new McpClientHandlers + { + SamplingHandler = (_, _, _) => new ValueTask(new CreateMessageResult + { + Model = "test-model", + Role = Role.Assistant, + Content = [new TextContentBlock { Text = "hello back" }], + }), + }, + }); + + var result = await client.CallToolAsync("sample-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("sample-ok:hello back", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task RequestRootsAsync_OnStatefulDraftSession_ResolvesViaLegacyRequest() + { + StartServer(); + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + ProtocolVersion = "DRAFT-2026-v1", + Capabilities = new ClientCapabilities + { + Roots = new RootsCapability(), + }, + Handlers = new McpClientHandlers + { + RootsHandler = (_, _) => new ValueTask(new ListRootsResult + { + Roots = [new Root { Uri = "file:///home", Name = "home" }], + }), + }, + }); + + var result = await client.CallToolAsync("roots-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("roots-ok:file:///home", Assert.IsType(result.Content[0]).Text); + } + + private static async Task ElicitToolAsync(McpServer server, CancellationToken cancellationToken) + { + var elicit = await server.ElicitAsync(new ElicitRequestParams + { + Message = "Need input", + RequestedSchema = new(), + }, cancellationToken); + return $"elicit-ok:{elicit.Action}"; + } + + private static async Task SampleToolAsync(McpServer server, CancellationToken cancellationToken) + { + var sample = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "ping" }], + }, + ], + MaxTokens = 16, + }, cancellationToken); + var text = sample.Content.OfType().FirstOrDefault()?.Text; + return $"sample-ok:{text}"; + } + + private static async Task RootsToolAsync(McpServer server, CancellationToken cancellationToken) + { + var roots = await server.RequestRootsAsync(new ListRootsRequestParams(), cancellationToken); + return $"roots-ok:{roots.Roots.FirstOrDefault()?.Uri}"; + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MrtrInputRequiredExceptionTests.cs b/tests/ModelContextProtocol.Tests/Server/MrtrInputRequiredExceptionTests.cs new file mode 100644 index 000000000..ac7e38f33 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MrtrInputRequiredExceptionTests.cs @@ -0,0 +1,61 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for the MRTR server API — IsMrtrSupported, InputRequiredException, +/// and client auto-retry of incomplete results. +/// +public class MrtrInputRequiredExceptionTests : ClientServerTestBase +{ + private readonly ServerMessageTracker _messageTracker = new(); + + public MrtrInputRequiredExceptionTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper, startServer: false) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.Configure(options => + { + options.ProtocolVersion = "DRAFT-2026-v1"; + _messageTracker.AddFilters(options.Filters.Message); + }); + + mcpServerBuilder.WithTools([ + McpServerTool.Create( + static string (McpServer server) => + { + throw new InputRequiredException(requestState: "should-not-work"); + }, + new McpServerToolCreateOptions + { + Name = "always-incomplete", + Description = "Tool that always throws InputRequiredException" + }), + ]); + } + + [Fact] + public async Task InputRequiredException_WithoutInputRequests_ExhaustsRetries() + { + StartServer(); + var clientOptions = new McpClientOptions(); + + await using var client = await CreateMcpClientForServer(clientOptions); + + // The always-incomplete tool throws InputRequiredException with only requestState + // and no inputRequests. The client has nothing to dispatch, so it keeps retrying + // with the same requestState until the retry budget is exhausted. + var exception = await Assert.ThrowsAsync(() => + client.CallToolAsync("always-incomplete", + cancellationToken: TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("more than", exception.Message); + } +}