From b880950a26ca758a2a9830a27ec63d8935a58ea6 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 5 Jun 2026 21:33:48 +0100 Subject: [PATCH 01/17] Deflake the session-level timeout test with trio's virtual clock (#2788) --- tests/interaction/lowlevel/test_timeouts.py | 24 ++++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index 2a3b885a6..c80e98405 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -1,8 +1,9 @@ """Request timeouts against the low-level Server, driven through the public client API. The handler blocks on an event that is never set, so the awaited response can never arrive and -any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore -set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +any positive timeout fires deterministically on the next event-loop pass. Per-request timeouts are +set to an effectively-zero duration; the session-level test runs on trio's virtual clock instead +(see the comment there). Either way the tests add no wall-clock time to the suite. (Zero itself cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) """ @@ -12,6 +13,7 @@ import anyio import pytest from inline_snapshot import snapshot +from trio.testing import MockClock from mcp import McpError, types from mcp.server.lowlevel import Server @@ -82,7 +84,19 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="still alive")])) +# A session-level timeout cannot use the effectively-zero pattern above: it also governs the +# initialize handshake, which must complete before the blocked tool call can wait the timeout +# out in full. Any real-clock margin is a bet against CI scheduler stalls (a 50ms value lost +# that bet in CI; the in-process handshake tail reaches ~190ms on a loaded windows runner), so +# this test runs on trio's virtual clock instead. With autojump, time advances only when every +# task is blocked: the handshake always has a runnable task and therefore cannot time out no +# matter how slow the runner, and once the tool call blocks on the never-answered request the +# run goes idle and the clock jumps straight to the deadline — deterministic, with no real wait. @requirement("protocol:timeout:session-default") +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) async def test_session_level_timeout_applies_to_every_request(connect: Connect) -> None: """A read timeout configured on the client applies to requests that do not set their own.""" server: Server[Any] = Server("blocker") @@ -93,12 +107,6 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB await anyio.Event().wait() # blocks until the session is torn down raise NotImplementedError # unreachable - # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the - # per-request timeouts: a session-level timeout also governs the initialize handshake, so the - # value must be long enough for the in-process handshake to complete before the blocked tool - # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual - # latency; lowering it only erodes the margin against CI scheduler jitter without saving - # anything perceptible. async with connect(server, read_timeout_seconds=timedelta(seconds=0.05)) as client: with pytest.raises(McpError) as exc_info: await client.call_tool("block", {}) From ea6a479e41b8030728c5960e96688bc8d2552155 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 08:34:57 +0000 Subject: [PATCH 02/17] tests: avoid abandoned-async-generator warnings under the trio backend The session-level timeout test now runs on trio's virtual clock. At teardown the streamable-HTTP client abandons its httpx/httpx-sse response generators; trio's asyncgen finalizer warns about each one (asyncio finalizes abandoned generators silently at loop shutdown), and filterwarnings=error turns that into a test failure. Scope-ignore the two third-party generator signatures, and make the bridge's response body delegate to the memory stream's own iterator so the harness itself leaves no abandoned generator. --- tests/interaction/conftest.py | 13 +++++++++++++ tests/interaction/transports/_bridge.py | 11 ++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index 597a87082..a6dd4d797 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -23,6 +23,19 @@ def pytest_configure(config: pytest.Config) -> None: "filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning" ) config.addinivalue_line("filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning") + # The trio-mockclock leg of the session-level timeout test (test_timeouts.py) is the suite's + # only test on the trio backend. v1's streamable-HTTP client abandons its httpx/httpx-sse + # response generators when the session task group is cancelled at teardown; asyncio finalizes + # abandoned async generators silently at loop shutdown, but trio's finalizer warns about each + # one (`Async generator ... was garbage collected before it had been exhausted`). The fixes + # live in `src/` on `main` and are out of scope for this tests-only backport. The filters are + # scoped to the two known httpx generator signatures so an unrelated leak still fails the suite. + config.addinivalue_line("filterwarnings", "ignore:Async generator 'httpx:ResourceWarning") + config.addinivalue_line( + "filterwarnings", + "ignore:.*async_generator object (Response.aiter_text|EventSource.aiter_sse)" + ":pytest.PytestUnraisableExceptionWarning", + ) _FACTORIES: dict[str, Connect] = { diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py index f78c6d14b..58274f60c 100644 --- a/tests/interaction/transports/_bridge.py +++ b/tests/interaction/transports/_bridge.py @@ -47,9 +47,14 @@ def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected self._chunks = chunks self._client_disconnected = client_disconnected - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self._chunks: - yield chunk + def __aiter__(self) -> AsyncIterator[bytes]: + # Delegate to the memory stream's own async iterator instead of wrapping it in an async + # generator. httpx abandons the iterator without closing it when a streamed response is + # closed mid-stream; trio's asyncgen finalizer warns about abandoned generators (asyncio + # finalizes them silently at loop shutdown), which would fail the suite's one trio-backend + # test. The memory stream is a plain async iterator with the same EndOfStream -> + # StopAsyncIteration semantics and is not tracked by that machinery. + return self._chunks async def aclose(self) -> None: self._client_disconnected.set() From 8d7389602750dbfcb4a3cd8bc16c7b6b54598c4a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:45:25 +0000 Subject: [PATCH 03/17] tests: re-export StreamingASGITransport as the sanctioned bridge import point --- tests/interaction/transports/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py index e69de29bb..b5bbb633c 100644 --- a/tests/interaction/transports/__init__.py +++ b/tests/interaction/transports/__init__.py @@ -0,0 +1,9 @@ +"""Transport-specific interaction tests, and the in-process streaming bridge they are built on. + +`StreamingASGITransport` is re-exported here as the sanctioned import point for test code +outside this suite (the bridge module itself is suite-private). +""" + +from tests.interaction.transports._bridge import StreamingASGITransport + +__all__ = ["StreamingASGITransport"] From 6a6e7b7276bde130c3bf7825890e1221a0e08075 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:54:30 +0000 Subject: [PATCH 04/17] Run transport security tests in process instead of over sockets (#2764) Tests-only backport to v1.x; adapted from main commit b3025f9. --- tests/server/test_sse_security.py | 308 ++++++----------- tests/server/test_streamable_http_manager.py | 30 +- tests/server/test_streamable_http_security.py | 319 +++++------------- 3 files changed, 202 insertions(+), 455 deletions(-) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 716a308a5..deb36e5f2 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,16 +1,13 @@ """Tests for SSE server request validation.""" import logging -import multiprocessing import re -import socket from collections.abc import Iterator from typing import Any import anyio import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -22,12 +19,15 @@ from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + @pytest.fixture(autouse=True) def reset_sse_starlette_exit_event() -> Iterator[None]: @@ -46,275 +46,161 @@ def clear() -> None: clear() -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - async def on_list_tools(self) -> list[Tool]: - return [] - - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" - app = SecurityTestServer() +def sse_security_client(security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: + """An httpx client whose requests are served in process by an SSE app with the given settings.""" + server = Server(SERVER_NAME) sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> Response: try: - async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: - await app.run(streams[0], streams[1], app.create_initialization_options()) + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read, write): + await server.run(read, write, server.create_initialization_options()) except ValueError as e: - # Validation error was already handled inside connect_sse + # Validation error was already handled inside connect_sse, which sent the rejection + # response itself; its non-empty body checkpoints, so the test reads the rejection + # status before the trailing Response() below sends a second response start. logger.debug(f"SSE connection failed validation: {e}") return Response() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] - - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) + # The SSE GET runs until it observes a disconnect, so the bridge must let the application + # drain on close rather than cancelling it. + transport = StreamingASGITransport(app, cancel_on_close=False) + return httpx.AsyncClient(transport=transport, base_url=BASE_URL) @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +async def test_sse_security_default_settings() -> None: + """With default security settings (protection disabled), any Host and Origin connect.""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with sse_security_client() as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): - """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list +async def test_sse_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): - """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins +async def test_sse_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Origin": "http://evil.com"}) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): - """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host +async def test_sse_security_post_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + fake_session_id = "12345678123456781234567812345678" - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + async with sse_security_client(security_settings) as client: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): - """Test SSE with security disabled.""" +async def test_sse_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host still connects.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "evil.com"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): - """Test SSE with custom allowed hosts.""" +async def test_sse_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts connects; hosts outside the list are still rejected.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 - # Test with non-allowed host - headers = {"Host": "evil.com"} + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "custom.host"}) as response: + assert response.status_code == 200 - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): - """Test SSE with wildcard port patterns.""" +async def test_sse_security_wildcard_ports() -> None: + """A `host:*` pattern accepts that host with any port, for Host and Origin alike.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - try: - # Test with various port numbers + async with sse_security_client(settings) as client: for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 - - headers = {"Origin": f"http://localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + async with client.stream("GET", "/sse", headers={"Host": f"localhost:{test_port}"}) as response: + assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with client.stream("GET", "/sse", headers={"Origin": f"http://localhost:{test_port}"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): - """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host +async def test_sse_security_post_valid_content_type() -> None: + """Every application/json Content-Type variant passes validation (reaching the session lookup).""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + # A well-formed session ID that no live session owns. + fake_session_id = "12345678123456781234567812345678" + + async with sse_security_client(security_settings) as client: + for content_type in valid_content_types: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # 404 proves the request passed the content-type check and reached the session lookup. + assert response.status_code == 404 + assert response.text == "Could not find session" def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 0ae07c43a..e5b971087 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,6 +1,7 @@ """Tests for StreamableHTTPSessionManager.""" import json +import logging from typing import Any from unittest.mock import AsyncMock, patch @@ -317,12 +318,33 @@ async def mock_receive(): assert error_data["error"]["message"] == "Session not found" +class _IdleTimeoutObserver(logging.Handler): + """Resolves `reaped` when the manager logs that a session's idle timeout fired.""" + + def __init__(self) -> None: + super().__init__() + self.reaped = anyio.Event() + + def emit(self, record: logging.LogRecord) -> None: + if "idle timeout" in record.getMessage(): + self.reaped.set() + + @pytest.mark.anyio -async def test_idle_session_is_reaped(): +async def test_idle_session_is_reaped(caplog: pytest.LogCaptureFixture, request: pytest.FixtureRequest): """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + # The reap is observed through the manager's own "idle timeout" log record: the manager pops + # the session synchronously after emitting it, before its next await, so a waiter woken by + # the record always finds the session gone. caplog.set_level enables INFO so it is created. + observer = _IdleTimeoutObserver() + manager_logger = logging.getLogger(streamable_http_manager.__name__) + manager_logger.addHandler(observer) + request.addfinalizer(lambda: manager_logger.removeHandler(observer)) + caplog.set_level(logging.INFO, logger=streamable_http_manager.__name__) + async with manager.run(): sent_messages: list[Message] = [] @@ -353,8 +375,10 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - # Wait for the 50ms idle timeout to fire and cleanup to complete - await anyio.sleep(0.1) + # Wait for the 50ms idle timeout to fire and the session to be unregistered. Re-requesting + # the session to poll for the 404 would push its idle deadline forward and keep it alive. + with anyio.fail_after(5): + await observer.reaped.wait() # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index a637b1dce..f13bb4a9b 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,293 +1,130 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import logging -import multiprocessing -import socket -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport -logger = logging.getLogger(__name__) SERVER_NAME = "test_streamable_http_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +@asynccontextmanager +async def streamable_http_security_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncIterator[httpx.AsyncClient]: + """Yield an httpx client served in process by a StreamableHTTP app with the given settings.""" + session_manager = StreamableHTTPSessionManager(app=Server(SERVER_NAME), security_settings=security_settings) + app = Starlette(routes=[Mount("/", app=session_manager.handle_request)]) -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" + async with session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as client: + yield client -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) +def _base_headers() -> dict[str, str]: + """Headers every well-formed request carries, so each test varies only the header under test.""" + return {"Accept": "application/json, text/event-stream", "Content-Type": "application/json"} - async def on_list_tools(self) -> list[Tool]: - return [] - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" - app = SecurityTestServer() - - # Create session manager with security settings - session_manager = StreamableHTTPSessionManager( - app=app, - json_response=False, - stateless=False, - security_settings=security_settings, - ) - - # Create the ASGI handler - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - # Create Starlette app with lifespan - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - routes = [ - Mount("/", app=handle_streamable_http), - ] - - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +def _initialize_body() -> dict[str, object]: + """A minimal initialize POST body; these tests assert header validation, not the handshake.""" + return {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): - """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_default_settings() -> None: + """With default security settings, a request with localhost headers is served.""" + async with streamable_http_security_client() as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers()) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): - """Test StreamableHTTP with invalid Host header.""" +async def test_streamable_http_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): - """Test StreamableHTTP with invalid Origin header.""" +async def test_streamable_http_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post( + "/", json=_initialize_body(), headers=_base_headers() | {"Origin": "http://evil.com"} + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): - """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" + async with streamable_http_security_client() as client: + response = await client.post("/", headers=_base_headers() | {"Content-Type": "text/plain"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + response = await client.post("/", headers={"Accept": "application/json, text/event-stream"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): - """Test StreamableHTTP with security disabled.""" +async def test_streamable_http_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host is still served.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): - """Test StreamableHTTP with custom allowed hosts.""" +async def test_streamable_http_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts is served.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "custom.host"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): - """Test StreamableHTTP GET request with security.""" +async def test_streamable_http_security_get_request() -> None: + """GET requests pass the same Host validation before any session handling.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: - # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "127.0.0.1"}) + # An allowed host passes security and fails on session validation instead. + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] From fd71a106b16a90eff78ed3471ce2b1c8936dd9ac Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:08:42 +0000 Subject: [PATCH 05/17] Filter known anyio stream teardown warnings in streamable HTTP security tests The three tests that complete the initialize handshake leak anyio memory streams at transport teardown when run in process. The scoped filters in tests/interaction/conftest.py cover full-suite runs, but they only load when that package is collected, so a targeted run of this file alone turns the ResourceWarning into a failure under filterwarnings=error. Mirror the module-level filterwarnings marks already used in tests/shared/test_sse.py. --- tests/server/test_streamable_http_security.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index f13bb4a9b..99bb4eaec 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -19,6 +19,21 @@ # Host header is a localhost form; nothing listens here. BASE_URL = "http://127.0.0.1:8000" +# v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown when +# run in process; the old subprocess harness never observed them. The interaction suite registers +# the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), +# but they only take effect when that package's conftest is loaded; these markers keep the tests +# that complete the initialize handshake passing in isolated runs. Markers are item-scoped, so +# they cannot cover the GC flush at session cleanup: an isolated run without xdist (`-n 0`) still +# exits nonzero after all tests pass. The default xdist runs (addopts has `-n auto`) are +# unaffected, as are full-suite runs, where the interaction conftest's ini-level filters apply. +# The filters are scoped to anyio's MemoryObject*Stream leak signature so an unrelated leak +# still fails the suite. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] + @asynccontextmanager async def streamable_http_security_client( From 92c4fe063859d2cbd75e96fe79fa5bd9e66ae171 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:48:59 +0000 Subject: [PATCH 06/17] Run SSE and Unicode transport tests in process instead of over sockets (#2765) Tests-only backport to v1.x; adapted from main commit ed39e73. --- tests/client/test_http_unicode.py | 270 ++++++--------- tests/shared/test_sse.py | 540 +++++++++++++----------------- 2 files changed, 334 insertions(+), 476 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index ec38f3583..dccba79de 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -5,15 +5,36 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing -import socket -from collections.abc import Generator +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +import httpx import pytest +from starlette.applications import Starlette +from starlette.routing import Mount +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from tests.test_helpers import wait_for_server +from mcp.server import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import CallToolResult, TextContent, Tool +from tests.interaction.transports import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown when +# run in process; the old subprocess harness never observed them. The interaction suite registers +# the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), +# but they only take effect when that package's conftest is loaded; these markers keep this file +# self-contained for isolated runs. The filters are scoped to anyio's MemoryObject*Stream leak +# signature so an unrelated leak still fails the suite. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -35,28 +56,12 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - # Import inside the function since this runs in a separate process - from collections.abc import AsyncGenerator - from contextlib import asynccontextmanager - from typing import Any - - import uvicorn - from starlette.applications import Starlette - from starlette.routing import Mount - - import mcp.types as types - from mcp.server import Server - from mcp.server.streamable_http_manager import StreamableHTTPSessionManager - from mcp.types import TextContent, Tool - - # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") +def make_unicode_server() -> Server[object, object]: + """The Unicode echo server: tool and prompt contents that exercise non-ASCII round trips.""" + server: Server[object, object] = Server(name="unicode_test_server") @server.list_tools() - async def list_tools() -> list[Tool]: - """List tools with Unicode descriptions.""" + async def handle_list_tools() -> list[Tool]: return [ Tool( name="echo_unicode", @@ -72,22 +77,12 @@ async def list_tools() -> list[Tool]: ] @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", "") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] - else: - raise ValueError(f"Unknown tool: {name}") + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "echo_unicode" + return CallToolResult(content=[TextContent(type="text", text=f"Echo: {arguments['text']}")]) @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: - """List prompts with Unicode names and descriptions.""" + async def handle_list_prompts() -> list[types.Prompt]: return [ types.Prompt( name="unicode_prompt", @@ -97,137 +92,90 @@ async def list_prompts() -> list[types.Prompt]: ] @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: - """Get a prompt with Unicode content.""" - if name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] - ) - raise ValueError(f"Unknown prompt: {name}") - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing - ) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - # Create an ASGI application - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lifespan, - ) - - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) + async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: + assert name == "unicode_prompt" + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), + ) + ] + ) + + return server + + +@asynccontextmanager +async def unicode_session() -> AsyncIterator[ClientSession]: + """Yield an initialized ClientSession speaking streamable HTTP (SSE responses) to the + Unicode test server, entirely in process.""" + # SSE response mode, so Unicode rides the SSE event encoding rather than a plain JSON body. + session_manager = StreamableHTTPSessionManager(app=make_unicode_server(), json_response=False) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + + async with ( + session_manager.run(), + # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects + # the bare /mcp path to /mcp/. + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, follow_redirects=True + ) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as ( + read_stream, + write_stream, + _get_session_id, + ), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call() -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 + async with unicode_session() as session: + # Test 1: List tools (server→client Unicode in descriptions) + tools = await session.list_tools() + assert len(tools.tools) == 1 - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Check Unicode in tool descriptions + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Test 2: Send Unicode text in tool call (client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Verify server correctly received and echoed back Unicode + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts() -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" + async with unicode_session() as session: + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7604450f8..6b0a0d5b1 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,18 +1,17 @@ +"""Tests for the SSE client and server transports, driven entirely in process.""" + import json -import multiprocessing -import socket -import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Iterable, Iterator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from pydantic import AnyUrl +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -23,188 +22,168 @@ from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client from mcp.server import Server +from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import McpError from mcp.types import ( + CallToolResult, EmptyResult, ErrorData, Implementation, InitializeResult, JSONRPCResponse, - ReadResourceResult, ServerCapabilities, TextContent, TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_server_for_SSE" - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - - -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: AnyUrl) -> str | bytes: - if uri.scheme == "foobar": - return f"Read {uri.host}" - elif uri.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {uri.host}" - - raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - ) - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# v1's HTTP server transports leak a handful of anyio memory streams on teardown when run in +# process; the old subprocess harness never observed them. The interaction suite registers the +# same two scoped filters globally from tests/interaction/conftest.py (see the comment there), +# but they only take effect when that package's conftest is loaded; these markers keep this file +# self-contained for isolated runs. The filters are scoped to anyio's MemoryObject*Stream leak +# signature so an unrelated leak still fails the suite. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] + + +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event after each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no + such attribute. This mirrors the autouse fixture in tests/interaction/conftest.py, which + guards the interaction suite the same way. + """ + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + + +def in_process_client_factory(app: Starlette) -> McpHttpClientFactory: + """An httpx_client_factory for sse_client whose clients are served in process by `app`.""" + + def factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE GET runs until it observes a disconnect, so the bridge must let the + # application drain on close rather than cancelling it. follow_redirects matches + # create_mcp_http_client, the factory this one stands in for. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) + + return factory + + +def make_test_server() -> Server[object, Request]: + """A server whose read_resource handler answers foobar:// URIs and 404s everything else.""" + server: Server[object, Request] = Server(SERVER_NAME) + + @server.read_resource() + async def handle_read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: + if uri.scheme == "foobar": + return [ReadResourceContents(content=f"Read {uri.host}", mime_type="text/plain")] + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) + + return server + + +def make_app(server: Server[Any, Any]) -> Starlette: + """Mount `server` on a Starlette app exposing the SSE transport at /sse and /messages/.""" + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the transport security behaviour itself is pinned by + # tests/server/test_sse_security.py. + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) + async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() +def make_server_app() -> Starlette: + return make_app(make_test_server()) - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) +@pytest.mark.anyio +async def test_raw_sse_connection() -> None: + """The SSE GET responds 200 with an event-stream content type, announcing the session + endpoint as its first event.""" + http_client = httpx.AsyncClient( + transport=StreamingASGITransport(make_server_app(), cancel_on_close=False), base_url=BASE_URL + ) -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client - + with anyio.fail_after(5): + async with http_client, http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -# Tests -@pytest.mark.anyio -async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: - """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - - line_number = 0 - async for line in response.aiter_lines(): # pragma: no branch - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 - - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + lines = response.aiter_lines() + assert await anext(lines) == "event: endpoint" + assert (await anext(lines)).startswith("data: /messages/?session_id=") @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection() -> None: + """A client initializes against, and pings, a server over the SSE transport.""" + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.serverInfo.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: - captured_session_id: str | None = None - - def on_session_created(session_id: str) -> None: - nonlocal captured_session_id - captured_session_id = session_id - - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: +async def test_sse_client_on_session_created() -> None: + """The session-created callback receives the new session ID before sse_client yields.""" + factory = in_process_client_factory(make_server_app()) + captured: list[str] = [] + + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - assert captured_session_id is not None - assert len(captured_session_id) > 0 + # Callback fires when the endpoint event arrives, before sse_client yields. + assert len(captured) == 1 + assert len(captured[0]) > 0 @pytest.mark.parametrize( @@ -219,13 +198,14 @@ def on_session_created(session_id: str) -> None: ], ) def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + """The session ID is read from the endpoint URL's sessionId/session_id query parameters.""" assert _extract_session_id_from_endpoint(endpoint_url) == expected @pytest.mark.anyio -async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch -) -> None: +async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + """No session-created callback fires when the endpoint URL carries no session ID.""" + factory = in_process_client_factory(make_server_app()) callback_mock = Mock() def mock_extract(url: str) -> None: @@ -233,17 +213,19 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - callback_mock.assert_not_called() + # Callback would have fired by now (endpoint event arrives before + # sse_client yields); if it hasn't, it won't. + callback_mock.assert_not_called() @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session() -> AsyncGenerator[ClientSession, None]: + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -253,6 +235,7 @@ async def initialized_sse_client_session(server: None, server_url: str) -> Async async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, ) -> None: + """A resource read round-trips its arguments and the handler's content over SSE.""" session = initialized_sse_client_session response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 @@ -264,226 +247,123 @@ async def test_sse_client_happy_request_and_response( async def test_sse_client_exception_handling( initialized_sse_client_session: ClientSession, ) -> None: + """A server-side McpError reaches the client with its message intact.""" session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work")) @pytest.mark.anyio -@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") -async def test_sse_client_timeout( # pragma: no cover - initialized_sse_client_session: ClientSession, -) -> None: - session = initialized_sse_client_session - - # sanity check that normal, fast responses are working - response = await session.read_resource(uri=AnyUrl("foobar://1")) - assert isinstance(response, ReadResourceResult) - - with anyio.move_on_after(3): - with pytest.raises(McpError, match="Read timed out"): - response = await session.read_resource(uri=AnyUrl("slow://2")) - # we should receive an error here - return - - pytest.fail("the client should have timed out and returned an error already") - - -def run_mounted_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - +async def test_sse_client_basic_connection_mounted_app() -> None: + """The SSE transport works unchanged when its app is mounted under a sub-path.""" + main_app = Starlette(routes=[Mount("/mounted_app", app=make_server_app())]) + factory = in_process_client_factory(main_app) -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: + async with sse_client(f"{BASE_URL}/mounted_app/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.serverInfo.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) -# Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): # pragma: no cover - def __init__(self): - super().__init__("request_context_server") - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - headers_info = {} - context = self.request_context - if context.request: - headers_info = dict(context.request.headers) - - if name == "echo_headers": - return [TextContent(type="text", text=json.dumps(headers_info))] - elif name == "echo_context": - context_data = { - "request_id": args.get("request_id"), - "headers": headers_info, - } - return [TextContent(type="text", text=json.dumps(context_data))] - - return [TextContent(type="text", text=f"Called {name}")] - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echoes request headers", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echoes request context", - inputSchema={ - "type": "object", - "properties": {"request_id": {"type": "string"}}, - "required": ["request_id"], - }, - ), - ] - - -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), +def make_context_server() -> Server[object, Request]: + """A server whose tools echo back the request headers seen via the request context.""" + server: Server[object, Request] = Server("request_context_server") + + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> CallToolResult: + assert name in ("echo_headers", "echo_context") + ctx = server.request_context + assert ctx.request is not None + headers_info = dict(ctx.request.headers) + + if name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + + context_data = { + "request_id": args.get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echoes request headers", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + inputSchema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), ] - ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() + return server -@pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") +def make_context_server_app() -> Starlette: + return make_app(make_context_server()) @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: - """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers +async def test_request_context_propagation() -> None: + """Custom HTTP headers on the SSE connection are visible to server handlers via the request context.""" + factory = in_process_client_factory(make_context_server_app()) + custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=custom_headers) as streams: + async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content = tool_result.content[0] + assert isinstance(content, TextContent) + headers_data = json.loads(content.text) - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: - """Test that request contexts are isolated between different SSE clients.""" +async def test_request_context_isolation() -> None: + """Each SSE connection's handlers see only that connection's request headers.""" + factory = in_process_client_factory(make_context_server_app()) contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=headers) as streams: + async with ClientSession(*streams) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + content = tool_result.content[0] + assert isinstance(content, TextContent) + contexts.append(json.loads(content.text)) - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -491,7 +371,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["headers"].get("x-custom-value") == f"value-{i}" -def test_sse_message_id_coercion(): +def test_sse_message_id_coercion() -> None: """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. @@ -525,7 +405,7 @@ def test_sse_message_id_coercion(): ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception @@ -602,3 +482,33 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert not isinstance(msg, Exception) assert isinstance(msg.message.root, types.JSONRPCResponse) assert msg.message.root.id == 1 + + +@pytest.mark.anyio +async def test_sse_session_cleanup_on_disconnect() -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 + + When a client disconnects, the server should remove the session from + _read_stream_writers. Without this cleanup, stale sessions accumulate and + POST requests to disconnected sessions return 202 Accepted followed by a + ClosedResourceError when the server tries to write to the dead stream. + """ + factory = in_process_client_factory(make_server_app()) + captured: list[str] = [] + + # Connect a client session, then disconnect + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # After disconnect, POST to the stale session should return 404 + # (not 202 as it did before the fix) + async with factory() as client: + response = await client.post( + f"/messages/?session_id={captured[0]}", + json={"jsonrpc": "2.0", "method": "ping", "id": 99}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 404 From a94f3d5827745c85c37fded476631e4e6a18d93d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:03:15 +0000 Subject: [PATCH 07/17] Guard in-process HTTP tests against sse-starlette's global exit event sse-starlette <3.0 stores its exit Event on the AppStatus class the first time an EventSourceResponse runs; the Event is bound to that test's event loop and breaks every later in-process SSE response on the same worker. test_http_unicode.py serves all its requests as EventSourceResponses (json_response=False) but had no reset fixture, so on a sse-starlette<3.0 install (CI's lowest-direct legs) it could poison the worker for any later SSE-based test. - Copy the autouse AppStatus reset fixture into test_http_unicode.py. - Reset on both sides of the yield, here and in test_sse.py, so each module also survives a stale Event left behind by an earlier test. - Correct the filterwarnings comments in both files: the item-scoped markers cannot cover the GC flush at session cleanup, so isolated runs without xdist (-n 0) still exit nonzero after all tests pass. --- tests/client/test_http_unicode.py | 35 +++++++++++++++++++++++++++---- tests/shared/test_sse.py | 21 ++++++++++++------- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index dccba79de..21097ce65 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -5,12 +5,13 @@ (server→client and client→server) using the streamable HTTP transport. """ -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager from typing import Any import httpx import pytest +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.routing import Mount @@ -28,14 +29,40 @@ # v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown when # run in process; the old subprocess harness never observed them. The interaction suite registers # the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), -# but they only take effect when that package's conftest is loaded; these markers keep this file -# self-contained for isolated runs. The filters are scoped to anyio's MemoryObject*Stream leak -# signature so an unrelated leak still fails the suite. +# but they only take effect when that package's conftest is loaded; these markers keep the tests +# themselves passing in isolated runs. Markers are item-scoped, so they cannot cover the GC +# flush at session cleanup: an isolated run without xdist (`-n 0`) still exits nonzero after all +# tests pass. The default xdist runs (addopts has `-n auto`) are unaffected, as are full-suite +# runs, where the interaction conftest's ini-level filters apply. The filters are scoped to +# anyio's MemoryObject*Stream leak signature so an unrelated leak still fails the suite. pytestmark = [ pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), ] + +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response (and `json_response=False` below means every request + in this module is served as one). sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixtures in tests/shared/test_sse.py and + tests/interaction/conftest.py. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + + # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { "cyrillic": "Слой хранилища, где располагаются", diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 6b0a0d5b1..afe93a7a4 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -49,9 +49,12 @@ # v1's HTTP server transports leak a handful of anyio memory streams on teardown when run in # process; the old subprocess harness never observed them. The interaction suite registers the # same two scoped filters globally from tests/interaction/conftest.py (see the comment there), -# but they only take effect when that package's conftest is loaded; these markers keep this file -# self-contained for isolated runs. The filters are scoped to anyio's MemoryObject*Stream leak -# signature so an unrelated leak still fails the suite. +# but they only take effect when that package's conftest is loaded; these markers keep the tests +# themselves passing in isolated runs. Markers are item-scoped, so they cannot cover the GC +# flush at session cleanup: an isolated run without xdist (`-n 0`) still exits nonzero after all +# tests pass. The default xdist runs (addopts has `-n auto`) are unaffected, as are full-suite +# runs, where the interaction conftest's ini-level filters apply. The filters are scoped to +# anyio's MemoryObject*Stream leak signature so an unrelated leak still fails the suite. pytestmark = [ pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), @@ -60,19 +63,23 @@ @pytest.fixture(autouse=True) def _reset_sse_starlette_exit_event() -> Iterator[None]: - """Reset sse-starlette's module-global exit Event after each test. + """Reset sse-starlette's module-global exit Event around each test. sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg installs it) stores an `anyio.Event` on the `AppStatus` class the first time an `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no - such attribute. This mirrors the autouse fixture in tests/interaction/conftest.py, which - guards the interaction suite the same way. + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixture in tests/interaction/conftest.py, which guards the + interaction suite the same way. """ - yield if hasattr(AppStatus, "should_exit_event"): # pragma: no branch # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover def in_process_client_factory(app: Starlette) -> McpHttpClientFactory: From 43cf9ec76242670fae1bb23cb1999540467ed3f1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 12:50:29 +0000 Subject: [PATCH 08/17] Run StreamableHTTP transport tests in process instead of over sockets (#2767) Tests-only backport to v1.x; adapted from main commit 19fe9fa. --- tests/shared/test_streamable_http.py | 2713 ++++++++++++-------------- 1 file changed, 1237 insertions(+), 1476 deletions(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 731dd20dd..dd0a413af 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1,14 +1,14 @@ -""" -Tests for the StreamableHTTP server and client transport. +"""Tests for the StreamableHTTP server and client transport. -Contains tests for both server and client sides of the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport, driven +entirely in process. """ import json -import multiprocessing -import socket import time -from collections.abc import Generator +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from datetime import timedelta from typing import Any from unittest.mock import MagicMock @@ -16,8 +16,6 @@ import anyio import httpx import pytest -import requests -import uvicorn from httpx_sse import ServerSentEvent from pydantic import AnyUrl from starlette.applications import Starlette @@ -45,7 +43,6 @@ ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage @@ -58,11 +55,23 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports._bridge import StreamingASGITransport + +# v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown +# (e.g. `_handle_get_request` only closes `sse_stream_reader` on the exception path; the +# session manager's per-session task-group cancel can race the per-request cleanup). The old +# socket-based version of this file ran the transport in a separate process and so never +# observed these `__del__`-time ResourceWarnings; running in-process via the streaming bridge +# does. The fixes live in `src/` on `main` and are out of scope for this tests-only change. +# The filters are scoped to anyio's `MemoryObject*Stream` leak signature so an unrelated leak +# still fails the suite; tests/interaction/conftest.py applies the same pair for the same reason. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] # Test constants SERVER_NAME = "test_streamable_http_server" -TEST_SESSION_ID = "test-session-id-12345" INIT_REQUEST = { "jsonrpc": "2.0", "method": "initialize", @@ -74,16 +83,19 @@ "id": "init-1", } +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: # pragma: no cover +def extract_protocol_version_from_sse(response: httpx.Response) -> str: """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): if line.startswith("data: "): init_data = json.loads(line[6:]) return init_data["result"]["protocolVersion"] - raise ValueError("Could not extract protocol version from SSE response") + raise ValueError("Could not extract protocol version from SSE response") # pragma: no cover # Simple in-memory event store for testing @@ -94,412 +106,263 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event( # pragma: no cover - self, stream_id: StreamId, message: types.JSONRPCMessage | None - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, ) -> StreamId | None: """Replay events after the specified ID.""" - # Find the stream ID of the last event - target_stream_id = None - for stream_id, event_id, _ in self._events: - if event_id == last_event_id: - target_stream_id = stream_id - break - - if target_stream_id is None: - # If event ID not found, return None - return None + # Find the stream ID of the last event; clients always resume from a stored event. + target_stream_id = next(stream_id for stream_id, event_id, _ in self._events if event_id == last_event_id) # Convert last_event_id to int for comparison last_event_id_int = int(last_event_id) - # Replay only events from the same stream with ID > last_event_id + # Replay only events from the same stream with ID > last_event_id, skipping priming + # events (None message). for stream_id, event_id, message in self._events: - if stream_id == target_stream_id and int(event_id) > last_event_id_int: - # Skip priming events (None message) - if message is not None: - await send_callback(EventMessage(message, event_id)) + if stream_id == target_stream_id and message is not None and int(event_id) > last_event_id_int: + await send_callback(EventMessage(message, event_id)) return target_stream_id -# Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - self._lock = None # Will be initialized in async context - - @self.read_resource() - async def handle_read_resource(uri: AnyUrl) -> str | bytes: - if uri.scheme == "foobar": - return f"Read {uri.host}" - elif uri.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {uri.host}" - - raise ValueError(f"Unknown resource: {uri}") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="test_tool_with_standalone_notification", - description="A test tool that sends a notification", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="test_sampling_tool", - description="A tool that triggers server-side sampling", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="wait_for_lock_with_notification", - description="A tool that sends a notification and waits for lock", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="release_lock", - description="A tool that releases the lock", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_stream_close", - description="A tool that closes SSE stream mid-operation", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_notifications_and_close", - description="Tool that sends notification1, closes stream, sends notification2, notification3", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - inputSchema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, - }, - ), - Tool( - name="tool_with_standalone_stream_close", - description="Tool that closes standalone GET stream mid-operation", - inputSchema={"type": "object", "properties": {}}, - ), - ] +@dataclass +class ServerState: + lock: anyio.Event = field(default_factory=anyio.Event) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - # When the tool is called, send a notification to test GET stream - if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource")) - return [TextContent(type="text", text=f"Called {name}")] +@asynccontextmanager +async def _server_lifespan(_server: Server[ServerState, Request]) -> AsyncIterator[ServerState]: + yield ServerState() - elif name == "long_running_with_checkpoints": - # Send notifications that are part of the response stream - # This simulates a long-running tool that sends logs - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, # need for stream association - ) +def _create_server() -> Server[ServerState, Request]: + server: Server[ServerState, Request] = Server(SERVER_NAME, lifespan=_server_lifespan) - await anyio.sleep(0.1) + @server.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + raise ValueError(f"Unknown resource: {uri}") - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), + ] - return [TextContent(type="text", text="Completed!")] + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context - elif name == "test_sampling_tool": - # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="Server needs client sampling"), - ) - ], - max_tokens=100, - related_request_id=ctx.request_id, - ) + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource")) + return [TextContent(type="text", text=f"Called {name}")] - # Return the sampling result in the tool response - # Since we're not passing tools param, result.content is single content - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) - return [ - TextContent( - type="text", - text=f"Response from sampling: {response}", + elif name == "test_sampling_tool": + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="Server needs client sampling"), ) - ] - - elif name == "wait_for_lock_with_notification": - # Initialize lock if not already done - if self._lock is None: - self._lock = anyio.Event() - - # First send a notification - await ctx.session.send_log_message( - level="info", - data="First notification before lock", - logger="lock_tool", - related_request_id=ctx.request_id, - ) - - # Now wait for the lock to be released - await self._lock.wait() + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) - # Send second notification after lock is released - await ctx.session.send_log_message( - level="info", - data="Second notification after lock", - logger="lock_tool", - related_request_id=ctx.request_id, + assert sampling_result.content.type == "text" + return [ + TextContent( + type="text", + text=f"Response from sampling: {sampling_result.content.text}", ) + ] - return [TextContent(type="text", text="Completed")] - - elif name == "release_lock": - assert self._lock is not None, "Lock must be initialized before releasing" - - # Release the lock - self._lock.set() - return [TextContent(type="text", text="Lock released")] + elif name == "wait_for_lock_with_notification": + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "tool_with_stream_close": - # Send notification before closing - await ctx.session.send_log_message( - level="info", - data="Before close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream (triggers client reconnect) - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Continue processing (events stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="After close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="Done")] - - elif name == "tool_with_multiple_notifications_and_close": - # Send notification1 - await ctx.session.send_log_message( - level="info", - data="notification1", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Send notification2, notification3 (stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="notification2", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - await ctx.session.send_log_message( - level="info", - data="notification3", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="All notifications sent")] - - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) - - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) + await ctx.lifespan_context.lock.wait() - if ctx.close_sse_stream: - await ctx.close_sse_stream() + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - await anyio.sleep(sleep_time) + return [TextContent(type="text", text="Completed")] - return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + elif name == "release_lock": + ctx.lifespan_context.lock.set() + return [TextContent(type="text", text="Lock released")] - elif name == "tool_with_standalone_stream_close": - # Test for GET stream reconnection - # 1. Send unsolicited notification via GET stream (no related_request_id) - await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_1")) + elif name == "tool_with_stream_close": + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="Done")] + + elif name == "tool_with_multiple_notifications_and_close": + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="All notifications sent")] - # Small delay to ensure notification is flushed before closing - await anyio.sleep(0.1) + elif name == "tool_with_standalone_stream_close": + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_1")) + await anyio.sleep(0.1) - # 2. Close the standalone GET stream - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + assert ctx.close_standalone_sse_stream is not None + await ctx.close_standalone_sse_stream() - # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) - await anyio.sleep(1.5) + await anyio.sleep(1.5) + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_2")) - # 4. Send another notification on the new GET stream connection - await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_2")) + return [TextContent(type="text", text="Standalone stream close test done")] - return [TextContent(type="text", text="Standalone stream close test done")] + return [TextContent(type="text", text=f"Called {name}")] - return [TextContent(type="text", text=f"Called {name}")] + return server -def create_app( +@asynccontextmanager +async def running_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover - """Create a Starlette application for testing using the session manager. + server: Server[Any, Request] | None = None, +) -> AsyncIterator[Starlette]: + """Serve the test server's streamable HTTP app in process for the duration. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. + server: Server to mount; defaults to the file's shared test server. """ - # Create server instance - server = ServerTest() - - # Create the session manager - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the protection itself is pinned by + # tests/server/test_streamable_http_security.py. session_manager = StreamableHTTPSessionManager( - app=server, + app=server if server is not None else _create_server(), event_store=event_store, json_response=is_json_response_enabled, - security_settings=security_settings, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), retry_interval=retry_interval, ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app - # Create an ASGI application that uses the session manager - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - return app +def make_client(app: Starlette, headers: dict[str, str] | None = None) -> httpx.AsyncClient: + """An httpx client served in process by `app`, with create_mcp_http_client's redirect default. -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. + (Starlette's Mount 307-redirects the bare /mcp path to /mcp/, which the SDK's own client + factory follows.) """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, + return httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, headers=headers, follow_redirects=True ) - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - import traceback - - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests -@pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - +# Test fixtures @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +async def basic_app() -> AsyncIterator[Starlette]: + """The test server's app with SSE response mode.""" + async with running_app() as app: + yield app @pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +async def json_app() -> AsyncIterator[Starlette]: + """The test server's app with JSON response mode.""" + async with running_app(is_json_response_enabled=True) as app: + yield app @pytest.fixture @@ -509,160 +372,138 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() +async def event_app(event_store: SimpleEventStore) -> AsyncIterator[tuple[SimpleEventStore, Starlette]]: + """The test server's app with an event store and retry_interval enabled.""" + async with running_app(event_store=event_store, retry_interval=500) as app: + yield event_store, app - # Wait for server to be running - wait_for_server(event_server_port) - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) +# Basic request validation tests +@pytest.mark.anyio +async def test_accept_header_validation(basic_app: Starlette) -> None: + """A POST without an Accept header is rejected with 406.""" + async with make_client(basic_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() +@pytest.mark.anyio +@pytest.mark.parametrize( + "accept_header", + [ + "text/html", + "application/*", + "text/*", + ], +) +async def test_accept_header_incompatible(basic_app: Starlette, accept_header: str) -> None: + """Accept headers that do not literally include both required media types are rejected for SSE mode. - # Wait for server to be running - wait_for_server(json_server_port) + (v1 matches Accept media types literally; wildcard support is a main-only change, #2152.) + """ + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - yield - # Clean up - proc.kill() - proc.join(timeout=2) +@pytest.mark.anyio +async def test_content_type_validation(basic_app: Starlette) -> None: + """A POST whose Content-Type is not application/json is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + content="This is not JSON", + ) + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +@pytest.mark.anyio +async def test_json_validation(basic_app: Starlette) -> None: + """A POST body that is not valid JSON is rejected with a parse error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + content="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text -@pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +@pytest.mark.anyio +async def test_json_parsing(basic_app: Starlette) -> None: + """Valid JSON that is not a JSON-RPC message is rejected with a validation error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text -# Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): - """Test that Accept header is properly validated.""" - # Test without Accept header - response = requests.post( - f"{basic_server_url}/mcp", - headers={"Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_content_type_validation(basic_server: None, basic_server_url: str): - """Test that Content-Type header is properly validated.""" - # Test with incorrect Content-Type - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "text/plain", - }, - data="This is not JSON", - ) - assert response.status_code == 400 - assert "Invalid Content-Type" in response.text +@pytest.mark.anyio +async def test_method_not_allowed(basic_app: Starlette) -> None: + """Unsupported HTTP methods are rejected with 405.""" + async with make_client(basic_app) as client: + response = await client.put( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): - """Test that JSON content is properly validated.""" - # Test with invalid JSON - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - data="this is not valid json", - ) - assert response.status_code == 400 - assert "Parse error" in response.text - - -def test_json_parsing(basic_server: None, basic_server_url: str): - """Test that JSON content is properly parse.""" - # Test with valid JSON but invalid JSON-RPC - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"foo": "bar"}, - ) - assert response.status_code == 400 - assert "Validation error" in response.text - - -def test_method_not_allowed(basic_server: None, basic_server_url: str): - """Test that unsupported HTTP methods are rejected.""" - # Test with unsupported method (PUT) - response = requests.put( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - - -def test_session_validation(basic_server: None, basic_server_url: str): - """Test session ID validation.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 400 - assert "Missing session ID" in response.text +@pytest.mark.anyio +async def test_session_validation(basic_app: Starlette) -> None: + """A non-initialize request without a session ID is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text -def test_session_id_pattern(): - """Test that SESSION_ID_PATTERN correctly validates session IDs.""" +def test_session_id_pattern() -> None: + """SESSION_ID_PATTERN accepts visible ASCII (0x21-0x7E) and rejects everything else.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ "test-session-id", @@ -696,8 +537,8 @@ def test_session_id_pattern(): assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on init.""" +def test_streamable_http_transport_init_validation() -> None: + """StreamableHTTPServerTransport accepts valid or absent session IDs and rejects invalid ones.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -719,299 +560,265 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): - """Test session termination via DELETE and subsequent request handling.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_session_termination(basic_app: Starlette) -> None: + """DELETE terminates the session, after which requests for it return 404.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + response = await client.delete( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) - # Now terminate the session - session_id = response.headers.get(MCP_SESSION_ID_HEADER) - response = requests.delete( - f"{basic_server_url}/mcp", - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 200 - - # Try to use the terminated session - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "ping", "id": 2}, - ) - assert response.status_code == 404 - assert "Session has been terminated" in response.text - - -def test_response(basic_server: None, basic_server_url: str): - """Test response handling for a valid request.""" - mcp_url = f"{basic_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_response(basic_app: Starlette) -> None: + """A request on an initialized session is answered on a text/event-stream response.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + + # Now get the session ID + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Try to use the session with proper headers + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + ) as tools_response: + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) - # Now get the session ID - session_id = response.headers.get(MCP_SESSION_ID_HEADER) +@pytest.mark.anyio +async def test_json_response(json_app: Starlette) -> None: + """With JSON response mode enabled, requests are answered with application/json bodies.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" - # Try to use the session with proper headers - tools_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, - ) - assert tools_response.status_code == 200 - assert tools_response.headers.get("Content-Type") == "text/event-stream" - - -def test_json_response(json_response_server: None, json_server_url: str): - """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): - """Test that json_response servers only require application/json in Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_json_response_accept_json_only(json_app: Starlette) -> None: + """JSON response mode only requires application/json in the Accept header.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests with incorrect Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Test with only text/event-stream (wrong for JSON server) - response = requests.post( - mcp_url, - headers={ - "Accept": "text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_get_sse_stream(basic_server: None, basic_server_url: str): - """Test establishing an SSE stream via GET request.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None +@pytest.mark.anyio +async def test_json_response_missing_accept_header(json_app: Starlette) -> None: + """JSON response mode still rejects requests without an Accept header.""" + async with make_client(json_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={ + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): # pragma: no cover - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Now attempt to establish an SSE stream via GET - get_response = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - # Verify we got a successful response with the right content type - assert get_response.status_code == 200 - assert get_response.headers.get("Content-Type") == "text/event-stream" +@pytest.mark.anyio +async def test_json_response_incorrect_accept_header(json_app: Starlette) -> None: + """JSON response mode rejects an Accept header that does not cover application/json.""" + async with make_client(json_app) as client: + # Test with only text/event-stream (wrong for JSON server) + response = await client.post( + "/mcp", + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - # Test that a second GET request gets rejected (only one stream allowed) - second_get = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - # Should get CONFLICT (409) since there's already a stream - # Note: This might fail if the first stream fully closed before this runs, - # but generally it should work in the test environment where it runs quickly - assert second_get.status_code == 409 - - -def test_get_validation(basic_server: None, basic_server_url: str): - """Test validation for GET requests.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 +@pytest.mark.anyio +async def test_get_sse_stream(basic_app: Starlette) -> None: + """GET establishes the standalone SSE stream, and a second GET is rejected with 409.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): # pragma: no cover - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Test without Accept header - response = requests.get( - mcp_url, - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - # Test with wrong Accept header - response = requests.get( - mcp_url, - headers={ - "Accept": "application/json", + # Now attempt to establish an SSE stream via GET + get_headers = { + "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text + } + # The streams enter in order, so the second GET arrives while the first is held open. + async with ( + client.stream("GET", "/mcp", headers=get_headers) as get_response, + client.stream("GET", "/mcp", headers=get_headers) as second_get, + ): + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" + # The second GET gets CONFLICT (409): only one standalone stream is allowed per session. + assert second_get.status_code == 409 -# Client-specific fixtures -@pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover - """Create test client matching the SSE test pattern.""" - async with httpx.AsyncClient(base_url=basic_server_url) as client: - yield client + +@pytest.mark.anyio +async def test_get_validation(basic_app: Starlette) -> None: + """A GET without an Accept header covering text/event-stream is rejected with 406.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) + + # Test without Accept header (suppress the httpx client default Accept: */*) + del client.headers["accept"] + response = await client.get( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = await client.get( + "/mcp", + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text +# Client-specific fixtures @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_app: Starlette) -> AsyncIterator[ClientSession]: """Create initialized StreamableHTTP client session.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - await session.initialize() - yield session + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): - """Test basic client connection with initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_streamable_http_client_basic_connection(basic_app: Starlette) -> None: + """A client initializes against a server over the StreamableHTTP transport.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): - """Test client resource read functionality.""" +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: + """A resource read round-trips its arguments and the handler's content.""" response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource")) assert len(response.contents) == 1 assert response.contents[0].uri == AnyUrl("foobar://test-resource") @@ -1020,11 +827,11 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): - """Test client tool invocation.""" +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: + """A tool call reaches the handler and returns its content.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 assert tools.tools[0].name == "test_tool" # Call the tool @@ -1035,8 +842,8 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): - """Test error handling in client.""" +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: + """A server-side error reaches the client as an McpError with the handler's message.""" with pytest.raises(McpError) as exc_info: await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error")) assert exc_info.value.error.code == 0 @@ -1044,66 +851,56 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): - """Test that session ID persists across requests.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_streamable_http_client_session_persistence(basic_app: Starlette) -> None: + """The session persists across multiple requests on one connection.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Make multiple requests to verify session persistence - tools = await session.list_tools() - assert len(tools.tools) == 10 + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 8 - # Read a resource - resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) - assert isinstance(resource.contents[0], TextResourceContents) is True - content = resource.contents[0] - assert isinstance(content, TextResourceContents) - assert content.text == "Read test-persist" + # Read a resource + resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): - """Test client with JSON response mode.""" - async with streamable_http_client(f"{json_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_streamable_http_client_json_response(json_app: Starlette) -> None: + """The client works identically against a server in JSON response mode.""" + async with ( + make_client(json_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Check tool listing - tools = await session.list_tools() - assert len(tools.tools) == 10 + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 8 - # Call a tool and verify JSON response handling - result = await session.call_tool("test_tool", {}) - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Called test_tool" + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): - """Test GET stream functionality for server-initiated messages.""" - import mcp.types as types - +async def test_streamable_http_client_get_stream(basic_app: Starlette) -> None: + """A server-initiated notification reaches the client on the standalone GET stream.""" notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications @@ -1113,79 +910,91 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - this triggers the GET stream setup - result = await session.initialize() - assert isinstance(result, InitializeResult) + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the special tool that sends a notification - await session.call_tool("test_tool_with_standalone_notification", {}) + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) - # Verify we received the notification - assert len(notifications_received) > 0 + # Verify we received the notification + assert len(notifications_received) > 0 - # Verify the notification is a ResourceUpdatedNotification - resource_update_found = False - for notif in notifications_received: - if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch - assert str(notif.root.params.uri) == "http://test_resource/" - resource_update_found = True + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.root.params.uri) == "http://test_resource/" + resource_update_found = True - assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" + + +def create_session_id_capturing_client(app: Starlette) -> tuple[httpx.AsyncClient, list[str]]: + """Create an in-process httpx client that captures the session ID from responses.""" + captured_ids: list[str] = [] + + async def capture_session_id(response: httpx.Response) -> None: + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + if session_id: + captured_ids.append(session_id) + + client = httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url=BASE_URL, + follow_redirects=True, + event_hooks={"response": [capture_session_id]}, + ) + return client, captured_ids @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): - """Test client session termination functionality.""" +async def test_streamable_http_client_session_termination(basic_app: Starlette) -> None: + """After the client terminates its session on close, a new connection with that session ID fails.""" + # Use httpx client with event hooks to capture session ID + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) - captured_session_id = None + async with httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert len(captured_ids) > 0 + captured_session_id = captured_ids[0] + assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} - # Create the streamable_http_client with a custom httpx client to capture headers - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - get_session_id, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - captured_session_id = get_session_id() - assert captured_session_id is not None - - # Make a request to confirm session is working - tools = await session.list_tools() - assert len(tools.tools) == 10 - - headers: dict[str, str] = {} # pragma: no cover - if captured_session_id: # pragma: no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 8 + + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, _, ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises( # pragma: no branch - McpError, - match="Session terminated", - ): + with pytest.raises(McpError, match="Session terminated"): # pragma: no branch await session.list_tools() @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): - """Test client session termination functionality with a 204 response. + basic_app: Starlette, monkeypatch: pytest.MonkeyPatch +) -> None: + """Session termination also succeeds when the server answers the DELETE with 204. This test patches the httpx client to return a 204 response for DELETEs. """ @@ -1210,55 +1019,50 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Apply the patch to the httpx client monkeypatch.setattr(httpx.AsyncClient, "delete", mock_delete) - captured_session_id = None + # Use httpx client with event hooks to capture session ID + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) - # Create the streamable_http_client with a custom httpx client to capture headers - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - get_session_id, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - captured_session_id = get_session_id() - assert captured_session_id is not None - - # Make a request to confirm session is working - tools = await session.list_tools() - assert len(tools.tools) == 10 - - headers: dict[str, str] = {} # pragma: no cover - if captured_session_id: # pragma: no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert len(captured_ids) > 0 + captured_session_id = captured_ids[0] + assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 8 + + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, _, ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises( # pragma: no branch - McpError, - match="Session terminated", - ): + with pytest.raises(McpError, match="Session terminated"): # pragma: no branch await session.list_tools() @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): - """Test client session resumption using sync primitives for reliable coordination.""" - _, server_url = event_server +async def test_streamable_http_client_resumption(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """A second client resumes an interrupted request with a resumption token and receives the rest.""" + _, app = event_app # Variables to track the state - captured_session_id = None - captured_resumption_token = None + captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - captured_protocol_version = None - first_notification_received = False + first_notification_received = anyio.Event() + resumption_token_received = anyio.Event() async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1268,83 +1072,88 @@ async def message_handler( # pragma: no branch # Look for our first notification if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch if message.root.params.data == "First notification before lock": - nonlocal first_notification_received - first_notification_received = True + first_notification_received.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + resumption_token_received.set() + + # Use httpx client with event hooks to capture session ID + httpx_client, captured_ids = create_session_id_capturing_client(app) # First, start the client session and begin the tool that waits on lock - async with streamable_http_client(f"{server_url}/mcp", terminate_on_close=False) as ( - read_stream, - write_stream, - get_session_id, - ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - captured_session_id = get_session_id() - assert captured_session_id is not None - # Capture the negotiated protocol version - captured_protocol_version = result.protocolVersion - - # Start the tool that will wait on lock in a task - async with anyio.create_task_group() as tg: - - async def run_tool(): - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams( - name="wait_for_lock_with_notification", arguments={} - ), - ) - ), - types.CallToolResult, - metadata=metadata, - ) + async with httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", terminate_on_close=False, http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( # pragma: no branch + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert len(captured_ids) > 0 + captured_session_id = captured_ids[0] + assert captured_session_id is not None + # Build phase-2 headers now while both values are in scope + headers: dict[str, Any] = { + MCP_SESSION_ID_HEADER: captured_session_id, + MCP_PROTOCOL_VERSION_HEADER: result.protocolVersion, + } - tg.start_soon(run_tool) + # Start the tool that will wait on lock in a task + async with anyio.create_task_group() as tg: # pragma: no branch - # Wait for the first notification and resumption token - while not first_notification_received or not captured_resumption_token: - await anyio.sleep(0.1) + async def run_tool(): + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + await session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams( + name="wait_for_lock_with_notification", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) - # Kill the client session while tool is waiting on lock - tg.cancel_scope.cancel() + tg.start_soon(run_tool) - # Verify we received exactly one notification - assert len(captured_notifications) == 1 # pragma: no cover - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover + # Wait for the first notification and resumption token + with anyio.fail_after(5): + await first_notification_received.wait() + await resumption_token_received.wait() - # Clear notifications for the second phase - captured_notifications = [] # pragma: no cover + # first_notification_received is set by message_handler immediately + # after appending to captured_notifications. The server tool is + # blocked on its lock, so nothing else can arrive before we cancel. + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) + assert captured_notifications[0].root.params.data == "First notification before lock" + # Reset for phase 2 before cancelling + captured_notifications.clear() - # Now resume the session with the same mcp-session-id and protocol version - headers: dict[str, Any] = {} # pragma: no cover - if captured_session_id: # pragma: no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - if captured_protocol_version: # pragma: no cover - headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version + # Kill the client session while tool is waiting on lock + tg.cancel_scope.cancel() - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client) as ( + async with make_client(app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, _, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: # pragma: no branch result = await session.send_request( types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="release_lock", arguments={}), - ) + types.CallToolRequest(params=types.CallToolRequestParams(name="release_lock", arguments={})) ), types.CallToolResult, ) @@ -1367,14 +1176,13 @@ async def run_tool(): # We should have received the remaining notifications assert len(captured_notifications) == 1 - - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "Second notification after lock" # pragma: no cover + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) + assert captured_notifications[0].root.params.data == "Second notification after lock" @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): - """Test server-initiated sampling request through streamable HTTP transport.""" +async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: + """A server-initiated sampling request reaches the client callback and its result the tool.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False captured_message_params = None @@ -1401,153 +1209,99 @@ async def sampling_callback( ) # Create client with sampling callback - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session, ): - async with ClientSession( - read_stream, - write_stream, - sampling_callback=sampling_callback, - ) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the tool that triggers server-side sampling - tool_result = await session.call_tool("test_sampling_tool", {}) + # Call the tool that triggers server-side sampling + tool_result = await session.call_tool("test_sampling_tool", {}) - # Verify the tool result contains the expected content - assert len(tool_result.content) == 1 - assert tool_result.content[0].type == "text" - assert "Response from sampling: Received message from server" in tool_result.content[0].text + # Verify the tool result contains the expected content + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert "Response from sampling: Received message from server" in tool_result.content[0].text - # Verify sampling callback was invoked - assert sampling_callback_invoked - assert captured_message_params is not None - assert len(captured_message_params.messages) == 1 - assert captured_message_params.messages[0].content.text == "Server needs client sampling" + # Verify sampling callback was invoked + assert sampling_callback_invoked + assert captured_message_params is not None + assert len(captured_message_params.messages) == 1 + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__("ContextAwareServer") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echo request headers from context", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echo request context with custom data", - inputSchema={ - "type": "object", - "properties": { - "request_id": {"type": "string"}, - }, - "required": ["request_id"], +def _create_context_server() -> Server[dict[str, Any], Request]: + server: Server[dict[str, Any], Request] = Server("ContextAwareServer") + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echo request headers from context", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + inputSchema={ + "type": "object", + "properties": { + "request_id": {"type": "string"}, }, - ), - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - - if name == "echo_headers": - # Access the request object from context - headers_info = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return [TextContent(type="text", text=json.dumps(headers_info))] - - elif name == "echo_context": - # Return full context information - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return [ - TextContent( - type="text", - text=json.dumps(context_data), - ) - ] - - return [TextContent(type="text", text=f"Unknown tool: {name}")] + "required": ["request_id"], + }, + ), + ] + + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + assert name in ("echo_headers", "echo_context") + assert isinstance(ctx.request, Request) + + if name == "echo_headers": + return [TextContent(type="text", text=json.dumps(dict(ctx.request.headers)))] + + context_data: dict[str, Any] = { + "request_id": args.get("request_id"), + "headers": dict(ctx.request.headers), + "method": ctx.request.method, + "path": ctx.request.url.path, + } + return [TextContent(type="text", text=json.dumps(context_data))] + return server -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" - server = ContextAwareServerTest() +@pytest.fixture +async def context_app() -> AsyncIterator[Starlette]: + """An app whose server echoes request context, served in process.""" + server = _create_context_server() session_manager = StreamableHTTPSessionManager( app=server, - event_store=None, - json_response=False, - ) - - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) - - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - - -@pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request context is properly propagated through StreamableHTTP.""" +async def test_streamablehttp_request_context_propagation(context_app: Starlette) -> None: + """Custom HTTP headers on the connection are visible to server handlers via ctx.request.""" custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=custom_headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, _, @@ -1572,11 +1326,11 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request contexts are isolated between StreamableHTTP clients.""" +async def test_streamablehttp_request_context_isolation(context_app: Starlette) -> None: + """Each connection's handlers see only that connection's request headers.""" contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = { "X-Request-Id": f"request-{i}", @@ -1584,8 +1338,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No "Authorization": f"Bearer token-{i}", } - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, _, @@ -1602,8 +1356,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No contexts.append(context_data) # Verify each request had its own context - assert len(contexts) == 3 # pragma: no cover - for i, ctx in enumerate(contexts): # pragma: no cover + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" @@ -1611,157 +1365,160 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): - """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_client_includes_protocol_version_header_after_init(context_app: Starlette) -> None: + """After initialization, every client request carries the negotiated protocol version header.""" + async with ( + make_client(context_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize and get the negotiated version - init_result = await session.initialize() - negotiated_version = init_result.protocolVersion - - # Call a tool that echoes headers to verify the header is present - tool_result = await session.call_tool("echo_headers", {}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify protocol version header is present - assert "mcp-protocol-version" in headers_data - assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version - - -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): - """Test that server returns 400 Bad Request version if header unsupported or invalid.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request with invalid protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "invalid-version", - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() - - # Test request with unsupported protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocolVersion - # Test request with valid protocol version (should succeed) - negotiated_version = extract_protocol_version_from_sse(init_response) + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, - ) - assert response.status_code == 200 - - -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): - """Test server accepts requests without protocol version header.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request without mcp-protocol-version header (backwards compatibility) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, - stream=True, - ) - assert response.status_code == 200 # Should succeed for backwards compatibility - assert response.headers.get("Content-Type") == "text/event-stream" + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): - """Test that cases where the client crashes are handled gracefully.""" +async def test_server_validates_protocol_version_header(basic_app: Starlette) -> None: + """An invalid or unsupported protocol version header is rejected with 400; the negotiated one passes.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request with invalid protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + negotiated_version = extract_protocol_version_from_sse(init_response) + + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None: + """A request without a protocol version header is accepted for backwards compatibility.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request without mcp-protocol-version header (backwards compatibility) + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + ) as response: + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_client_crash_handled(basic_app: Starlette) -> None: + """A client crashing mid-session does not prevent later clients from connecting.""" # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - raise Exception("client crash") + await session.initialize() + raise Exception("client crash") - # Run bad client a few times to trigger the crash + # Run bad client a few times to trigger the crash. The crash surfaces wrapped in exception + # groups whose exact shape is not the subject here — what matters is that the server survives. for _ in range(3): try: await bad_client() except Exception: pass - await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - tools = await session.list_tools() - assert tools.tools + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): - """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" +async def test_handle_sse_event_skips_empty_data() -> None: + """_handle_sse_event skips empty SSE data (keep-alive pings) without writing to the stream.""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") # Create a mock SSE event with empty data (keep-alive ping) @@ -1786,8 +1543,8 @@ async def test_handle_sse_event_skips_empty_data(): @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version(): - """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" +async def test_priming_event_not_sent_for_old_protocol_version() -> None: + """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1815,8 +1572,8 @@ async def test_priming_event_not_sent_for_old_protocol_version(): @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store(): - """Test that _maybe_send_priming_event returns early when no event_store is configured.""" +async def test_priming_event_not_sent_without_event_store() -> None: + """_maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1835,8 +1592,8 @@ async def test_priming_event_not_sent_without_event_store(): @pytest.mark.anyio -async def test_priming_event_includes_retry_interval(): - """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" +async def test_priming_event_includes_retry_interval() -> None: + """_maybe_send_priming_event includes the retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( "/mcp", @@ -1864,8 +1621,8 @@ async def test_priming_event_includes_retry_interval(): @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): - """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: + """close_sse_stream callbacks are only provided for protocol versions that support polling.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1897,83 +1654,78 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() @pytest.mark.anyio async def test_streamable_http_client_receives_priming_event( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should receive priming event (resumption token update) on POST SSE stream.""" - _, server_url = event_server + _, app = event_app captured_resumption_tokens: list[str] = [] async def on_resumption_token_update(token: str) -> None: captured_resumption_tokens.append(token) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + await session.initialize() - # Call tool with resumption token callback via send_request - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="test_tool", arguments={}), - ) - ), - types.CallToolResult, - metadata=metadata, - ) - assert result is not None - - # Should have received priming event token BEFORE response data - # Priming event = 1 token (empty data, id only) - # Response = 1 token (actual JSON-RPC response) - # Total = 2 tokens minimum - assert len(captured_resumption_tokens) >= 2, ( - f"Server must send priming event before response. " - f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" - ) - assert captured_resumption_tokens[0] is not None + # Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})) + ), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None @pytest.mark.anyio async def test_server_close_sse_stream_via_context( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Server tool can call ctx.close_sse_stream() to close connection.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + await session.initialize() - # Call tool that closes stream mid-operation - # This should NOT raise NotImplementedError when fully implemented - result = await session.call_tool("tool_with_stream_close", {}) + # Call tool that closes stream mid-operation + result = await session.call_tool("tool_with_stream_close", {}) - # Client should still receive complete response (via auto-reconnect) - assert result is not None - assert len(result.content) > 0 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_auto_reconnects( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" - _, server_url = event_server + _, app = event_app captured_notifications: list[str] = [] async def message_handler( @@ -1985,71 +1737,63 @@ async def message_handler( if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch captured_notifications.append(str(message.root.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification - # 2. Closes SSE stream - # 3. Sends more notifications (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_stream_close", {}) - - # Client should have auto-reconnected and received ALL notifications - assert len(captured_notifications) >= 2, ( - "Client should auto-reconnect and receive notifications sent both before and after stream close" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_respects_retry_interval( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client MUST respect retry field, waiting specified ms before reconnecting.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + await session.initialize() - start_time = time.monotonic() - result = await session.call_tool("tool_with_stream_close", {}) - elapsed = time.monotonic() - start_time + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time - # Verify result was received - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" - # The elapsed time should include at least the retry interval - # if reconnection occurred. This test may be flaky depending on - # implementation details, but demonstrates the expected behavior. - # Note: This assertion may need adjustment based on actual implementation - assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + # The elapsed time should include at least the retry interval (500ms) before + # the client reconnected; the tool's own work only accounts for ~100ms. + assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" @pytest.mark.anyio async def test_streamable_http_sse_polling_full_cycle( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """End-to-end test: server closes stream, client reconnects, receives all events.""" - _, server_url = event_server + _, app = event_app all_notifications: list[str] = [] async def message_handler( @@ -2061,43 +1805,38 @@ async def message_handler( if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch all_notifications.append(str(message.root.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() - - # Call tool that simulates polling pattern: - # 1. Server sends priming event - # 2. Server sends "Before close" notification - # 3. Server closes stream (calls close_sse_stream) - # 4. (client reconnects automatically) - # 5. Server sends "After close" notification - # 6. Server sends final response - result = await session.call_tool("tool_with_stream_close", {}) - - # Verify all notifications received in order - assert "Before close" in all_notifications, "Should receive notification sent before stream close" - assert "After close" in all_notifications, ( - "Should receive notification sent after stream close (via auto-reconnect)" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_events_replayed_after_disconnect( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Events sent while client is disconnected should be replayed on reconnect.""" - _, server_url = event_server + _, app = event_app notification_data: list[str] = [] async def message_handler( @@ -2109,45 +1848,43 @@ async def message_handler( if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch notification_data.append(str(message.root.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() + await session.initialize() - # Tool sends: notification1, close_stream, notification2, notification3, response - # Client should receive all notifications even though 2&3 were sent during disconnect - result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) - assert "notification1" in notification_data, "Should receive notification1 (sent before close)" - assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" - assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" - # Verify order: notification1 should come before notification2 and notification3 - idx1 = notification_data.index("notification1") - idx2 = notification_data.index("notification2") - idx3 = notification_data.index("notification3") - assert idx1 < idx2 < idx3, "Notifications should be received in order" + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "All notifications sent" + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" @pytest.mark.anyio -async def test_streamable_http_multiple_reconnections( - event_server: tuple[SimpleEventStore, str], -): - """Verify multiple close_sse_stream() calls each trigger a client reconnect. +async def test_streamable_http_multiple_reconnections() -> None: + """Every close_sse_stream() severs a live connection and triggers its own client reconnect. - Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure - client has time to reconnect before the next checkpoint. + The tool closes its SSE stream three times; before each next cycle it waits until the + client has observed the previous cycle's two new resumption tokens (the checkpoint and the + new connection's priming event). The priming event is sent only after the server has + re-registered the resumed stream, so once the client holds its token the next close is + guaranteed to sever a live connection rather than silently no-op — making the exact token + count below a consequence of causality, not timing margins. This pins reconnect-per-close + accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval. With 3 checkpoints, we expect 8 resumption tokens: - 1 priming (initial POST connection) @@ -2155,50 +1892,77 @@ async def test_streamable_http_multiple_reconnections( - 3 priming (one per reconnect after each close) - 1 response """ - _, server_url = event_server resumption_tokens: list[str] = [] + # milestones[n] fires when the client has observed n tokens. After the initial priming + # (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the + # reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens. + milestones = {3: anyio.Event(), 5: anyio.Event(), 7: anyio.Event()} async def on_resumption_token(token: str) -> None: resumption_tokens.append(token) - - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Use send_request with metadata to track resumption tokens - metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), - ) - ), - types.CallToolResult, - metadata=metadata, + milestone = milestones.get(len(resumption_tokens)) + if milestone is not None: + milestone.set() + + server: Server[dict[str, Any], Request] = Server("multi_reconnect_server") + + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + assert name == "multi_close_tool" + for i, milestone in enumerate(milestones.values()): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Client and server share one event loop, so the tool can wait directly on the + # client-side callback observing the reconnect. + with anyio.fail_after(5): + await milestone.wait() + return [TextContent(type="text", text="Completed 3 checkpoints")] + + async with ( + # retry_interval is small to keep the test fast, but nonzero so each dying connection + # finishes unwinding before its replacement registers. + running_app(event_store=SimpleEventStore(), retry_interval=50, server=server) as app, + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="multi_close_tool", arguments={}), + ) + ), + types.CallToolResult, + metadata=metadata, + ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Completed 3 checkpoints" in result.content[0].text + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( # pragma: no cover - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio -async def test_standalone_get_stream_reconnection( - event_server: tuple[SimpleEventStore, str], -) -> None: - """ - Test that standalone GET stream automatically reconnects after server closes it. +async def test_standalone_get_stream_reconnection(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """Test that standalone GET stream automatically reconnects after server closes it. Verifies: 1. Client receives notification 1 via GET stream @@ -2206,10 +1970,10 @@ async def test_standalone_get_stream_reconnection( 3. Client reconnects with Last-Event-ID 4. Client receives notification 2 on new connection - Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + Note: Requires the event store app because close_standalone_sse_stream callback is only provided when event_store is configured and protocol version >= 2025-11-25. """ - _, server_url = event_server + _, app = event_app received_notifications: list[str] = [] async def message_handler( @@ -2221,53 +1985,46 @@ async def message_handler( if isinstance(message.root, types.ResourceUpdatedNotification): # pragma: no branch received_notifications.append(str(message.root.params.uri)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification_1 via GET stream - # 2. Closes standalone GET stream - # 3. Sends notification_2 (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_standalone_stream_close", {}) - - # Verify the tool completed - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Standalone stream close test done" - - # Verify both notifications were received - assert "http://notification_1/" in received_notifications, ( - f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" - ) - assert "http://notification_2/" in received_notifications, ( - f"Should receive notification 2 after reconnect, got: {received_notifications}" - ) + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1/" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2/" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + ) @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: - """Test that streamable_http_client does not mutate the provided httpx client's headers.""" +async def test_streamable_http_client_does_not_mutate_provided_client(basic_app: Starlette) -> None: + """streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { "X-Custom-Header": "custom-value", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=original_headers, follow_redirects=True) as custom_client: + async with make_client(basic_app, headers=original_headers) as custom_client: # Use the client with streamable_http_client - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=custom_client) as ( read_stream, write_stream, _, @@ -2289,22 +2046,16 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that MCP protocol headers override httpx.AsyncClient default headers.""" +async def test_streamable_http_client_mcp_headers_override_defaults(context_app: Starlette) -> None: + """MCP protocol headers override the httpx client's default headers in actual requests.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests - async with httpx.AsyncClient(follow_redirects=True) as client: + async with make_client(context_app) as client: # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as ( - read_stream, - write_stream, - _, - ): + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2324,22 +2075,16 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that both custom headers and MCP protocol headers are sent in requests.""" +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_app: Starlette) -> None: + """Custom client headers and MCP protocol headers are both sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", "X-Request-Id": "req-123", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as ( - read_stream, - write_stream, - _, - ): + async with make_client(context_app, headers=custom_headers) as client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2363,12 +2108,11 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert headers_data["content-type"] == "application/json" -@pytest.mark.anyio -async def test_streamable_http_transport_deprecated_params_ignored(basic_server: None, basic_server_url: str) -> None: - """Test that deprecated parameters passed to StreamableHTTPTransport are properly ignored.""" +def test_streamable_http_transport_deprecated_params_ignored() -> None: + """Deprecated parameters passed to StreamableHTTPTransport are accepted but ignored.""" with pytest.warns(DeprecationWarning): transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated] - url=f"{basic_server_url}/mcp", + url=f"{BASE_URL}/mcp", headers={"X-Should-Be-Ignored": "ignored"}, timeout=999, sse_read_timeout=timedelta(seconds=999), @@ -2382,10 +2126,27 @@ async def test_streamable_http_transport_deprecated_params_ignored(basic_server: @pytest.mark.anyio -async def test_streamablehttp_client_deprecation_warning(basic_server: None, basic_server_url: str) -> None: - """Test that the old streamablehttp_client() function issues a deprecation warning.""" +async def test_streamablehttp_client_deprecation_warning(basic_app: Starlette) -> None: + """The old streamablehttp_client() function issues a deprecation warning.""" + + def in_process_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(basic_app), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) + with pytest.warns(DeprecationWarning, match="Use `streamable_http_client` instead"): - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( # pyright: ignore[reportDeprecated] + async with streamablehttp_client( # pyright: ignore[reportDeprecated] + f"{BASE_URL}/mcp", httpx_client_factory=in_process_client_factory + ) as ( read_stream, write_stream, _, From b80a9bd7c60517d240c133b6e0190ca7ad22567e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:03:17 +0000 Subject: [PATCH 09/17] Force per-test GC so leaked-stream warnings stay inside the scoped filters The module-level filterwarnings marks only apply while a test in this file is the active warning context, but v1's leaked memory streams sit in reference cycles and were garbage-collected later: at pytest's session-unconfigure unraisable sweep (every test in this file passed but pytest exited 1 when run without xdist, e.g. -n 0 for --pdb debugging), or during an unrelated test on the same worker, where the global filterwarnings = ["error"] turned the deallocator warning into a failure of that innocent test. An autouse fixture now runs gc.collect() in each test's teardown, so the deallocator warnings fire where the scoped ignores apply. Verified: every previously failing single-test and full-file -n 0 invocation now exits 0, and cross-file runs no longer misattribute the warnings. --- tests/shared/test_streamable_http.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index dd0a413af..7b337cee0 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,9 +4,10 @@ entirely in process. """ +import gc import json import time -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import timedelta @@ -70,6 +71,23 @@ pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), ] + +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + # Test constants SERVER_NAME = "test_streamable_http_server" INIT_REQUEST = { From 9dc96d66b50b2ce715873bd985ec9f173ff36cba Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:47:50 +0000 Subject: [PATCH 10/17] Guard streamable-HTTP test modules against sse-starlette's module-global exit event tests/shared/test_streamable_http.py and tests/server/test_streamable_http_security.py serve responses in process; on sse-starlette <3.0 (what the lowest-direct CI legs install) the first EventSourceResponse binds AppStatus.should_exit_event to that test's event loop and every later SSE response in the module fails. Port the both-sides reset fixture from tests/client/test_http_unicode.py into both files. Also give the security module the same per-test gc.collect() teardown as test_streamable_http.py so leaked-stream warnings stay inside the module's scoped filters, and drop the comment claiming a -n 0 run exits nonzero after all tests pass: with the flush in place it exits 0. --- tests/server/test_streamable_http_security.py | 49 ++++++++++++++++--- tests/shared/test_streamable_http.py | 22 +++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 99bb4eaec..9f4117dff 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,10 +1,12 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -from collections.abc import AsyncIterator +import gc +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager import httpx import pytest +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.routing import Mount @@ -23,18 +25,51 @@ # run in process; the old subprocess harness never observed them. The interaction suite registers # the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), # but they only take effect when that package's conftest is loaded; these markers keep the tests -# that complete the initialize handshake passing in isolated runs. Markers are item-scoped, so -# they cannot cover the GC flush at session cleanup: an isolated run without xdist (`-n 0`) still -# exits nonzero after all tests pass. The default xdist runs (addopts has `-n auto`) are -# unaffected, as are full-suite runs, where the interaction conftest's ini-level filters apply. -# The filters are scoped to anyio's MemoryObject*Stream leak signature so an unrelated leak -# still fails the suite. +# that complete the initialize handshake passing in isolated runs. The filters are scoped to +# anyio's MemoryObject*Stream leak signature so an unrelated leak still fails the suite. pytestmark = [ pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), ] +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixtures in tests/shared/test_sse.py and + tests/interaction/conftest.py. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + + @asynccontextmanager async def streamable_http_security_client( security_settings: TransportSecuritySettings | None = None, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7b337cee0..832cbd13e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -19,6 +19,7 @@ import pytest from httpx_sse import ServerSentEvent from pydantic import AnyUrl +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount @@ -88,6 +89,27 @@ def _collect_leaked_streams() -> Iterator[None]: gc.collect() +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixtures in tests/shared/test_sse.py and + tests/interaction/conftest.py. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + + # Test constants SERVER_NAME = "test_streamable_http_server" INIT_REQUEST = { From 600c96e3f983362cb06ae45ea8e37baa6a543789 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:47:20 +0000 Subject: [PATCH 11/17] Extend the per-test GC fixture to the SSE and Unicode HTTP test modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tests/shared/test_sse.py and tests/client/test_http_unicode.py carry the same scoped MemoryObject*Stream filterwarnings marks as tests/shared/test_streamable_http.py and had the same gap: the leaked streams sit in reference cycles, so their deallocator warnings fired at pytest's session-unconfigure unraisable sweep, exiting 1 after all tests passed when run without xdist (-n 0). Both files documented this in their filter-mark comments. Apply the same autouse fixture that fixed test_streamable_http.py — a gc.collect() in each test's teardown, where the item-scoped ignores still apply — and update the comments to describe the fixed behavior. Verified: -n 0 runs of each file now exit 0 (previously 26 passed/exit 1 and 2 passed/exit 1), and five consecutive default-option runs of both files together pass. --- tests/client/test_http_unicode.py | 27 ++++++++++++++++++++++----- tests/shared/test_sse.py | 27 ++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 21097ce65..5be19ac72 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -5,6 +5,7 @@ (server→client and client→server) using the streamable HTTP transport. """ +import gc from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager from typing import Any @@ -30,17 +31,33 @@ # run in process; the old subprocess harness never observed them. The interaction suite registers # the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), # but they only take effect when that package's conftest is loaded; these markers keep the tests -# themselves passing in isolated runs. Markers are item-scoped, so they cannot cover the GC -# flush at session cleanup: an isolated run without xdist (`-n 0`) still exits nonzero after all -# tests pass. The default xdist runs (addopts has `-n auto`) are unaffected, as are full-suite -# runs, where the interaction conftest's ini-level filters apply. The filters are scoped to -# anyio's MemoryObject*Stream leak signature so an unrelated leak still fails the suite. +# themselves passing in isolated runs. Markers are item-scoped, so the autouse +# `_collect_leaked_streams` fixture below garbage-collects each test's leaks inside its own +# teardown, where these filters apply; without it, leaks GC'd at session cleanup escape the +# scoped ignores. The filters are scoped to anyio's MemoryObject*Stream leak signature so an +# unrelated leak still fails the suite. pytestmark = [ pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), ] +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + @pytest.fixture(autouse=True) def _reset_sse_starlette_exit_event() -> Iterator[None]: """Reset sse-starlette's module-global exit Event around each test. diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index afe93a7a4..56d03f530 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,5 +1,6 @@ """Tests for the SSE client and server transports, driven entirely in process.""" +import gc import json from collections.abc import AsyncGenerator, Iterable, Iterator from typing import Any @@ -50,17 +51,33 @@ # process; the old subprocess harness never observed them. The interaction suite registers the # same two scoped filters globally from tests/interaction/conftest.py (see the comment there), # but they only take effect when that package's conftest is loaded; these markers keep the tests -# themselves passing in isolated runs. Markers are item-scoped, so they cannot cover the GC -# flush at session cleanup: an isolated run without xdist (`-n 0`) still exits nonzero after all -# tests pass. The default xdist runs (addopts has `-n auto`) are unaffected, as are full-suite -# runs, where the interaction conftest's ini-level filters apply. The filters are scoped to -# anyio's MemoryObject*Stream leak signature so an unrelated leak still fails the suite. +# themselves passing in isolated runs. Markers are item-scoped, so the autouse +# `_collect_leaked_streams` fixture below garbage-collects each test's leaks inside its own +# teardown, where these filters apply; without it, leaks GC'd at session cleanup escape the +# scoped ignores. The filters are scoped to anyio's MemoryObject*Stream leak signature so an +# unrelated leak still fails the suite. pytestmark = [ pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), ] +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + @pytest.fixture(autouse=True) def _reset_sse_starlette_exit_event() -> Iterator[None]: """Reset sse-starlette's module-global exit Event around each test. From 39938a2f8a784551221f807cef46caf6e29d60c0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 13:53:24 +0000 Subject: [PATCH 12/17] Harden the streaming ASGI bridge against trailing responses Once an application sends the final http.response.body chunk, the bridge now drops any further http.response.start/body messages, matching what a real ASGI server's client observes. Starlette's request_response sends a trailing response when an endpoint's sub-application has already completed a rejection response, so the SSE security rejection tests no longer depend on scheduling for the client to read the first status. Pinned by two new bridge contract tests (registered as harness self-tests). Also widen the trio-leg unraisable-warning filter to the whole httpx/httpx-sse generator chain: abandoning EventSource.aiter_sse abandons the nested aiter_lines -> aiter_text -> aiter_bytes -> aiter_raw generators, and which link the finalizer reports depends on GC timing and Python version. --- tests/interaction/conftest.py | 13 +++--- tests/interaction/test_coverage.py | 2 + tests/interaction/transports/_bridge.py | 13 +++++- tests/interaction/transports/test_bridge.py | 47 +++++++++++++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index a6dd4d797..92119cb1a 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -27,14 +27,17 @@ def pytest_configure(config: pytest.Config) -> None: # only test on the trio backend. v1's streamable-HTTP client abandons its httpx/httpx-sse # response generators when the session task group is cancelled at teardown; asyncio finalizes # abandoned async generators silently at loop shutdown, but trio's finalizer warns about each - # one (`Async generator ... was garbage collected before it had been exhausted`). The fixes - # live in `src/` on `main` and are out of scope for this tests-only backport. The filters are - # scoped to the two known httpx generator signatures so an unrelated leak still fails the suite. + # one (`Async generator ... was garbage collected before it had been exhausted`). Abandoning + # `EventSource.aiter_sse` abandons the whole generator chain nested under it (`aiter_lines` -> + # `aiter_text` -> `aiter_bytes` -> `aiter_raw`), and which links the finalizer reports depends + # on GC timing and Python version. The fixes live in `src/` on `main` and are out of scope for + # this tests-only backport. The filters are scoped to the httpx/httpx-sse generator signatures + # (every generator in that chain lives on `Response` or `EventSource`) so an unrelated leak + # still fails the suite. config.addinivalue_line("filterwarnings", "ignore:Async generator 'httpx:ResourceWarning") config.addinivalue_line( "filterwarnings", - "ignore:.*async_generator object (Response.aiter_text|EventSource.aiter_sse)" - ":pytest.PytestUnraisableExceptionWarning", + "ignore:.*async_generator object (Response|EventSource).aiter_:pytest.PytestUnraisableExceptionWarning", ) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 7821c1eed..3abb7bf04 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -27,6 +27,8 @@ _HARNESS_SELF_TESTS = { "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_a_second_response_after_the_first_completes_is_invisible_to_the_client", + "tests.interaction.transports.test_bridge.test_body_chunks_after_the_final_chunk_are_ignored", "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py index 58274f60c..17c2432ae 100644 --- a/tests/interaction/transports/_bridge.py +++ b/tests/interaction/transports/_bridge.py @@ -12,6 +12,8 @@ - The request body is buffered before the application is invoked (MCP requests are small JSON documents); the response streams chunk by chunk. +- The response ends at the first `http.response.body` whose `more_body` is falsy; anything the + application sends after that is ignored, exactly as a real server's client never observes it. - Closing the response — or the whole client — delivers `http.disconnect` to the application, exactly as a real server sees when its peer goes away. - An exception the application raises before sending `http.response.start` fails the originating @@ -116,6 +118,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: request_delivered = False client_disconnected = anyio.Event() response_started = anyio.Event() + response_complete = False response_status = 0 response_headers: list[tuple[bytes, bytes]] = [] application_error: Exception | None = None @@ -130,7 +133,14 @@ async def receive_request() -> Message: return {"type": "http.disconnect"} async def send_response(message: Message) -> None: - nonlocal response_status, response_headers + nonlocal response_complete, response_status, response_headers + if response_complete: + # The response ended with the final body chunk below; a real server's client never + # observes anything sent after that, so drop it. Starlette's `request_response` + # makes this path real: an endpoint whose sub-application already sent a complete + # rejection response (the legacy SSE transport's request validation) still returns + # a `Response`, which sends a trailing second start/body pair. + return if message["type"] == "http.response.start": response_status = message["status"] response_headers = list(message.get("headers", [])) @@ -141,6 +151,7 @@ async def send_response(message: Message) -> None: if body: await chunk_writer.send(body) if not message.get("more_body", False): + response_complete = True await chunk_writer.aclose() async def run_application() -> None: diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py index 7420b9d90..d51fbd88d 100644 --- a/tests/interaction/transports/test_bridge.py +++ b/tests/interaction/transports/test_bridge.py @@ -40,6 +40,53 @@ async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: assert chunks == [b"first", b"second"] +async def test_a_second_response_after_the_first_completes_is_invisible_to_the_client() -> None: + """Only the first complete response reaches the client; a trailing start/body pair is dropped. + + Starlette's `request_response` produces exactly this sequence when an endpoint's + sub-application has already sent a complete rejection response (the legacy SSE transport's + request validation): the endpoint still returns a `Response`, which sends a second response. + """ + + async def double_responding_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 421, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"rejected", "more_body": False}) + await send({"type": "http.response.start", "status": 200, "headers": [(b"x-late", b"yes")]}) + await send({"type": "http.response.body", "body": b"too late", "more_body": False}) + + transport = StreamingASGITransport(double_responding_app) + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + response = await http.get("/double") + + assert response.status_code == 421 + assert response.text == "rejected" + assert "x-late" not in response.headers + + +async def test_body_chunks_after_the_final_chunk_are_ignored() -> None: + """Extra body chunks after `more_body: False` neither reach the client nor fail the application.""" + application_finished = anyio.Event() + + async def overflowing_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b"complete", "more_body": False}) + await send({"type": "http.response.body", "body": b"overflow", "more_body": True}) + application_finished.set() + + transport = StreamingASGITransport(overflowing_app) + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + response = await http.get("/overflow") + with anyio.fail_after(5): + await application_finished.wait() + + assert response.status_code == 200 + assert response.text == "complete" + + async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: """A client that closes the response early is seen by the application as an http.disconnect.""" seen_after_request: list[Message] = [] From b732867ccfabb20fd870a173831934930498c268 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 14:20:46 +0000 Subject: [PATCH 13/17] Verify SSE request-context isolation inside each connection block Statements after the connection loop sit in the Python 3.11 trace-loss shadow (python/cpython#106749) and were reported uncovered on 3.11 matrix cells; verifying inside the traced region removes the gap without a coverage exclusion. --- tests/shared/test_sse.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 56d03f530..fcdb11bf9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -371,9 +371,11 @@ async def test_request_context_propagation() -> None: async def test_request_context_isolation() -> None: """Each SSE connection's handlers see only that connection's request headers.""" factory = in_process_client_factory(make_context_server_app()) - contexts: list[dict[str, Any]] = [] - # Connect three clients in turn, each with its own headers. + # Connect three clients in turn, each with its own headers. Each connection is + # verified inside its own block: on Python 3.11 the line tracer is lost once an + # async-with teardown throws (python/cpython#106749), so statements placed after + # this loop would be reported uncovered on some matrix cells. for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} @@ -386,13 +388,10 @@ async def test_request_context_isolation() -> None: assert len(tool_result.content) == 1 content = tool_result.content[0] assert isinstance(content, TextContent) - contexts.append(json.loads(content.text)) - - assert len(contexts) == 3 - for i, ctx in enumerate(contexts): - assert ctx["request_id"] == f"request-{i}" - assert ctx["headers"].get("x-request-id") == f"request-{i}" - assert ctx["headers"].get("x-custom-value") == f"value-{i}" + ctx = json.loads(content.text) + assert ctx["request_id"] == f"request-{i}" + assert ctx["headers"].get("x-request-id") == f"request-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}" def test_sse_message_id_coercion() -> None: From 3e53636d7608b3e685dd3a5798a3393031c94436 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 14:36:33 +0000 Subject: [PATCH 14/17] Exclude the connection loop's exit arc from branch coverage The arc fires after the final async-with teardown, inside the same Python 3.11 trace-loss shadow as the previous commit; only the arc is excluded, the loop body stays measured. --- tests/shared/test_sse.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index fcdb11bf9..856606488 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -375,8 +375,10 @@ async def test_request_context_isolation() -> None: # Connect three clients in turn, each with its own headers. Each connection is # verified inside its own block: on Python 3.11 the line tracer is lost once an # async-with teardown throws (python/cpython#106749), so statements placed after - # this loop would be reported uncovered on some matrix cells. - for i in range(3): + # this loop would be reported uncovered on some matrix cells. The loop's exit + # arc fires after the final teardown and sits in the same shadow, hence the + # branch exclusion. + for i in range(3): # pragma: no branch headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=headers) as streams: From 16ceeea3a98a104634e2e22c7afaf74ff312727d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:00:48 +0000 Subject: [PATCH 15/17] Exclude post-teardown header checks from 3.11 coverage measurement The checks must stay after the streamable_http_client context exits so a teardown-time mutation cannot escape them, which on Python 3.11 places them in the trace-loss shadow of python/cpython#106749: they execute and assert on every matrix cell but go unmeasured on 3.11. Same convention as the suite's existing lax exclusions. --- tests/shared/test_streamable_http.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 832cbd13e..61d779324 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -2073,16 +2073,21 @@ async def test_streamable_http_client_does_not_mutate_provided_client(basic_app: result = await session.initialize() assert isinstance(result, InitializeResult) - # Verify client headers were not mutated with MCP protocol headers + # Verify client headers were not mutated with MCP protocol headers. + # These checks deliberately sit after the streamable_http_client context + # exits (a teardown-time mutation would otherwise escape them), which on + # Python 3.11 places them in the post-teardown trace-loss shadow + # (python/cpython#106749): they run and assert on every leg but go + # unmeasured on 3.11 cells, hence the lax exclusions. # If accept header exists, it should still be httpx default, not MCP's - if "accept" in custom_client.headers: # pragma: no branch + if "accept" in custom_client.headers: # pragma: lax no cover assert custom_client.headers.get("accept") == "*/*" # MCP content-type should not have been added - assert custom_client.headers.get("content-type") != "application/json" + assert custom_client.headers.get("content-type") != "application/json" # pragma: lax no cover # Verify custom headers are still present and unchanged - assert custom_client.headers.get("X-Custom-Header") == "custom-value" - assert custom_client.headers.get("Authorization") == "Bearer test-token" + assert custom_client.headers.get("X-Custom-Header") == "custom-value" # pragma: lax no cover + assert custom_client.headers.get("Authorization") == "Bearer test-token" # pragma: lax no cover @pytest.mark.anyio From e155e1bb9e304f6b77078e5851884e1ca8be6cda Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:27:35 +0000 Subject: [PATCH 16/17] Deflake child process cleanup tests with polling instead of fixed sleeps The TestChildProcessCleanup tests asserted that spawned writers had started after fixed sleeps (0.5s startup, 0.3s observation window). On loaded CI runners the child interpreter can take longer than that to boot, so the "child should be writing" assertion failed with "assert 0 > 0" before the cleanup logic under test ever ran. Replace the fixed sleeps with bounded polling: - _wait_for_first_write polls until a marker file has grown, proving the writer reached its write loop, with a 15s timeout. - _wait_for_writes_to_stop polls until two samples taken 0.3s apart (3x the writers' 0.1s write interval) observe the same size; if the file never stops growing the timeout fails the test, so a genuine cleanup failure is still reported. Also terminate the spawned process tree in each test's finally block, so a failed assertion can no longer leak a running process tree, and collect garbage before leaving each test so subprocess transports are finalized while the test's ResourceWarning filters are still active. Removing the unconditional sleeps also makes the tests faster. --- tests/client/test_stdio.py | 354 ++++++++++++++++++------------------- 1 file changed, 176 insertions(+), 178 deletions(-) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index ba58da732..42be287a4 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,4 +1,5 @@ import errno +import gc import os import shutil import sys @@ -10,7 +11,12 @@ import pytest from mcp.client.session import ClientSession -from mcp.client.stdio import StdioServerParameters, _create_platform_compatible_process, stdio_client +from mcp.client.stdio import ( + StdioServerParameters, + _create_platform_compatible_process, + _terminate_process_tree, + stdio_client, +) from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -219,6 +225,36 @@ def sigint_handler(signum, frame): raise +async def _wait_for_first_write(path: str) -> None: + """Poll until the file at *path* exists and has grown beyond its initial empty state. + + The marker files below are created empty before the writer is spawned, so any + growth proves the writing process booted and reached its write loop. Polling + replaces fixed startup sleeps, which flake on loaded machines where interpreter + startup can exceed any fixed window. Bounded so a writer that never starts + fails the test instead of hanging it. + """ + with anyio.fail_after(15): + while not os.path.exists(path) or os.path.getsize(path) == 0: + await anyio.sleep(0.05) + + +async def _wait_for_writes_to_stop(path: str) -> None: + """Poll until the file at *path* stops growing. + + Returns once two consecutive samples taken 0.3 seconds apart (three times the + writers' 0.1 second write interval) observe the same size. The sentinel forces + at least one full sampling interval before the first comparison. If the file + never stops growing, the timeout fails the test: a writer that survives + _terminate_process_tree is a genuine cleanup failure that must not be masked. + """ + last_size = -1 + with anyio.fail_after(15): + while os.path.getsize(path) != last_size: + last_size = os.path.getsize(path) + await anyio.sleep(0.3) + + class TestChildProcessCleanup: """ Tests for child process cleanup functionality using _terminate_process_tree. @@ -259,84 +295,66 @@ async def test_basic_child_process_cleanup(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: parent_marker = f.name - try: - # Parent script that spawns a child process - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Mark that parent started - with open({escape_path_for_python(parent_marker)}, 'w') as f: - f.write('parent started\\n') - - # Child script that writes continuously - child_script = f''' - import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"{time.time()}") - f.flush() - time.sleep(0.1) - ''' - - # Start the child process - child = subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent just sleeps + # Parent script that spawns a child process + parent_script = textwrap.dedent( + f""" + import subprocess + import sys + import time + import os + + # Mark that parent started + with open({escape_path_for_python(parent_marker)}, 'w') as f: + f.write('parent started\\n') + + # Child script that writes continuously + child_script = f''' + import time + with open({escape_path_for_python(marker_file)}, 'a') as f: while True: + f.write(f"{time.time()}") + f.flush() time.sleep(0.1) - """ - ) + ''' - print("\nStarting child process termination test...") + # Start the child process + child = subprocess.Popen([sys.executable, '-c', child_script]) - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + # Parent just sleeps + while True: + time.sleep(0.1) + """ + ) - # Wait for processes to start - await anyio.sleep(0.5) + # Start the parent process + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - # Verify parent started - assert os.path.exists(parent_marker), "Parent process didn't start" + try: + # Wait for the parent to start and the child to reach its write loop + await _wait_for_first_write(parent_marker) + assert os.path.getsize(parent_marker) > 0, "Parent process didn't start" - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - initial_size = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size_after_wait = os.path.getsize(marker_file) - assert size_after_wait > initial_size, "Child process should be writing" - print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)") + await _wait_for_first_write(marker_file) + assert os.path.getsize(marker_file) > 0, "Child process should be writing" # Terminate using our function - print("Terminating process and children...") - from mcp.client.stdio import _terminate_process_tree - await _terminate_process_tree(proc) - # Verify processes stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size_after_cleanup = os.path.getsize(marker_file) - await anyio.sleep(0.5) - final_size = os.path.getsize(marker_file) - - print(f"After cleanup: file size {size_after_cleanup} -> {final_size}") - assert final_size == size_after_cleanup, ( - f"Child process still running! File grew by {final_size - size_after_cleanup} bytes" - ) - - print("SUCCESS: Child process was properly terminated") - + # Verify the child stopped writing; a survivor times out and fails the test + await _wait_for_writes_to_stop(marker_file) finally: + # Terminate again so no failure above can leak the spawned tree + # (safe: _terminate_process_tree tolerates an already-dead tree) + await _terminate_process_tree(proc) # Clean up files for f in [marker_file, parent_marker]: try: os.unlink(f) except OSError: # pragma: no cover pass + # Collect subprocess transports now, while this test's warning filters + # are active, so GC-time ResourceWarnings cannot hit a later test + gc.collect() @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") @@ -353,88 +371,78 @@ async def test_nested_process_tree(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f3: grandchild_file = f3.name - try: - # Simple nested process tree test - # We create parent -> child -> grandchild, each writing to a file - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Child will spawn grandchild and write to child file - child_script = f'''import subprocess - import sys - import time - - # Grandchild just writes to file - grandchild_script = \"\"\"import time - with open({escape_path_for_python(grandchild_file)}, 'a') as f: - while True: - f.write(f"gc {{time.time()}}") - f.flush() - time.sleep(0.1)\"\"\" - - # Spawn grandchild - subprocess.Popen([sys.executable, '-c', grandchild_script]) - - # Child writes to its file - with open({escape_path_for_python(child_file)}, 'a') as f: - while True: - f.write(f"c {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Spawn child process - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent writes to its file - with open({escape_path_for_python(parent_file)}, 'a') as f: - while True: - f.write(f"p {time.time()}") - f.flush() - time.sleep(0.1) - """ - ) + # Simple nested process tree test + # We create parent -> child -> grandchild, each writing to a file + parent_script = textwrap.dedent( + f""" + import subprocess + import sys + import time + import os + + # Child will spawn grandchild and write to child file + child_script = f'''import subprocess + import sys + import time + + # Grandchild just writes to file + grandchild_script = \"\"\"import time + with open({escape_path_for_python(grandchild_file)}, 'a') as f: + while True: + f.write(f"gc {{time.time()}}") + f.flush() + time.sleep(0.1)\"\"\" - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + # Spawn grandchild + subprocess.Popen([sys.executable, '-c', grandchild_script]) - # Let all processes start - await anyio.sleep(1.0) + # Child writes to its file + with open({escape_path_for_python(child_file)}, 'a') as f: + while True: + f.write(f"c {time.time()}") + f.flush() + time.sleep(0.1)''' - # Verify all are writing - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - initial_size = os.path.getsize(file_path) - await anyio.sleep(0.3) - new_size = os.path.getsize(file_path) - assert new_size > initial_size, f"{name} process should be writing" + # Spawn child process + subprocess.Popen([sys.executable, '-c', child_script]) - # Terminate the whole tree - from mcp.client.stdio import _terminate_process_tree + # Parent writes to its file + with open({escape_path_for_python(parent_file)}, 'a') as f: + while True: + f.write(f"p {time.time()}") + f.flush() + time.sleep(0.1) + """ + ) - await _terminate_process_tree(proc) + # Start the parent process + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - # Verify all stopped - await anyio.sleep(0.5) + try: + # Wait for every level of the tree to reach its write loop for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - size1 = os.path.getsize(file_path) - await anyio.sleep(0.3) - size2 = os.path.getsize(file_path) - assert size1 == size2, f"{name} still writing after cleanup!" + await _wait_for_first_write(file_path) + assert os.path.getsize(file_path) > 0, f"{name} process should be writing" - print("SUCCESS: All processes in tree terminated") + # Terminate the whole tree + await _terminate_process_tree(proc) + # Verify every level stopped writing; a survivor times out and fails the test + for file_path in (parent_file, child_file, grandchild_file): + await _wait_for_writes_to_stop(file_path) finally: + # Terminate again so no failure above can leak the spawned tree + # (safe: _terminate_process_tree tolerates an already-dead tree) + await _terminate_process_tree(proc) # Clean up all marker files for f in [parent_file, child_file, grandchild_file]: try: os.unlink(f) except OSError: # pragma: no cover pass + # Collect subprocess transports now, while this test's warning filters + # are active, so GC-time ResourceWarnings cannot hit a later test + gc.collect() @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") @@ -448,72 +456,62 @@ async def test_early_parent_exit(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: marker_file = f.name - try: - # Parent that spawns child and waits briefly - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import signal - - # Child that continues running - child_script = f'''import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"child {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Start child in same process group - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent waits a bit then exits on SIGTERM - def handle_term(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGTERM, handle_term) - - # Wait + # Parent that spawns child and waits briefly + parent_script = textwrap.dedent( + f""" + import subprocess + import sys + import time + import signal + + # Child that continues running + child_script = f'''import time + with open({escape_path_for_python(marker_file)}, 'a') as f: while True: - time.sleep(0.1) - """ - ) + f.write(f"child {time.time()}") + f.flush() + time.sleep(0.1)''' - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + # Start child in same process group + subprocess.Popen([sys.executable, '-c', child_script]) - # Let child start writing - await anyio.sleep(0.5) + # Parent waits a bit then exits on SIGTERM + def handle_term(sig, frame): + sys.exit(0) - # Verify child is writing - if os.path.exists(marker_file): # pragma: no cover - size1 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size2 = os.path.getsize(marker_file) - assert size2 > size1, "Child should be writing" + signal.signal(signal.SIGTERM, handle_term) - # Terminate - this will kill the process group even if parent exits first - from mcp.client.stdio import _terminate_process_tree + # Wait + while True: + time.sleep(0.1) + """ + ) - await _terminate_process_tree(proc) + # Start the parent process + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - # Verify child stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size3 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size4 = os.path.getsize(marker_file) - assert size3 == size4, "Child should be terminated" + try: + # Wait for the child to reach its write loop + await _wait_for_first_write(marker_file) + assert os.path.getsize(marker_file) > 0, "Child should be writing" - print("SUCCESS: Child terminated even with parent exit during cleanup") + # Terminate - this will kill the process group even if parent exits first + await _terminate_process_tree(proc) + # Verify the child stopped writing; a survivor times out and fails the test + await _wait_for_writes_to_stop(marker_file) finally: + # Terminate again so no failure above can leak the spawned tree + # (safe: _terminate_process_tree tolerates an already-dead tree) + await _terminate_process_tree(proc) # Clean up marker file try: os.unlink(marker_file) except OSError: # pragma: no cover pass + # Collect subprocess transports now, while this test's warning filters + # are active, so GC-time ResourceWarnings cannot hit a later test + gc.collect() @pytest.mark.anyio From 4a17006005d42aa1eb5bccb345bd92d0bee49f76 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:40:01 +0000 Subject: [PATCH 17/17] Harden stop detection and avoid redundant tree termination in cleanup tests Two follow-up fixes to the child process cleanup tests: Require three consecutive stable samples before declaring writes stopped. The previous check exited on the first pair of samples 0.3s apart with no growth, retrying within the 15s budget, which made it easier for a CPU-starved (but alive) writer to be mistaken for a terminated one. The counter resets on any observed growth, and a file that never stops growing still fails the test via the timeout. Only re-terminate the process tree in the finally block if the test failed before reaching its own _terminate_process_tree call. The unconditional second call ran termination against an already-closed job object handle on Windows and logged a spurious fallback warning on POSIX in every passing run. The skipped branch only executes on failing runs, so it is excluded from coverage. --- tests/client/test_stdio.py | 41 +++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 42be287a4..4a93f998b 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -242,16 +242,26 @@ async def _wait_for_first_write(path: str) -> None: async def _wait_for_writes_to_stop(path: str) -> None: """Poll until the file at *path* stops growing. - Returns once two consecutive samples taken 0.3 seconds apart (three times the - writers' 0.1 second write interval) observe the same size. The sentinel forces - at least one full sampling interval before the first comparison. If the file + Returns once the size is unchanged across three successive 0.3 second gaps + (each three times the writers' 0.1 second write interval), so a writer that + is merely starved of CPU for a single gap is not mistaken for a terminated + one. Any observed growth resets the consecutive-stable counter. The sentinel + forces at least one non-stable iteration before counting starts. If the file never stops growing, the timeout fails the test: a writer that survives _terminate_process_tree is a genuine cleanup failure that must not be masked. """ last_size = -1 + stable_pairs = 0 with anyio.fail_after(15): - while os.path.getsize(path) != last_size: - last_size = os.path.getsize(path) + while True: + current_size = os.path.getsize(path) + if current_size == last_size: + stable_pairs += 1 + else: + stable_pairs = 0 + last_size = current_size + if stable_pairs == 3: + return await anyio.sleep(0.3) @@ -328,6 +338,7 @@ async def test_basic_child_process_cleanup(self): # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + tree_killed = False try: # Wait for the parent to start and the child to reach its write loop @@ -339,13 +350,13 @@ async def test_basic_child_process_cleanup(self): # Terminate using our function await _terminate_process_tree(proc) + tree_killed = True # Verify the child stopped writing; a survivor times out and fails the test await _wait_for_writes_to_stop(marker_file) finally: - # Terminate again so no failure above can leak the spawned tree - # (safe: _terminate_process_tree tolerates an already-dead tree) - await _terminate_process_tree(proc) + if not tree_killed: # pragma: no cover - cleanup only reached when the test failed mid-flight + await _terminate_process_tree(proc) # Clean up files for f in [marker_file, parent_marker]: try: @@ -417,6 +428,7 @@ async def test_nested_process_tree(self): # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + tree_killed = False try: # Wait for every level of the tree to reach its write loop @@ -426,14 +438,14 @@ async def test_nested_process_tree(self): # Terminate the whole tree await _terminate_process_tree(proc) + tree_killed = True # Verify every level stopped writing; a survivor times out and fails the test for file_path in (parent_file, child_file, grandchild_file): await _wait_for_writes_to_stop(file_path) finally: - # Terminate again so no failure above can leak the spawned tree - # (safe: _terminate_process_tree tolerates an already-dead tree) - await _terminate_process_tree(proc) + if not tree_killed: # pragma: no cover - cleanup only reached when the test failed mid-flight + await _terminate_process_tree(proc) # Clean up all marker files for f in [parent_file, child_file, grandchild_file]: try: @@ -489,6 +501,7 @@ def handle_term(sig, frame): # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + tree_killed = False try: # Wait for the child to reach its write loop @@ -497,13 +510,13 @@ def handle_term(sig, frame): # Terminate - this will kill the process group even if parent exits first await _terminate_process_tree(proc) + tree_killed = True # Verify the child stopped writing; a survivor times out and fails the test await _wait_for_writes_to_stop(marker_file) finally: - # Terminate again so no failure above can leak the spawned tree - # (safe: _terminate_process_tree tolerates an already-dead tree) - await _terminate_process_tree(proc) + if not tree_killed: # pragma: no cover - cleanup only reached when the test failed mid-flight + await _terminate_process_tree(proc) # Clean up marker file try: os.unlink(marker_file)