diff --git a/pyproject.toml b/pyproject.toml
index 9b617f667..ecbf4c6d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,12 +24,13 @@ classifiers = [
 dependencies = [
     "anyio>=4.5",
     "httpx>=0.27",
-    "httpx-sse>=0.4",
+    "httpx-sse>=0.4.1",
     "pydantic>=2.7.2,<3.0.0",
     "starlette>=0.27",
     "python-multipart>=0.0.9",
     "sse-starlette>=1.6.1",
     "pydantic-settings>=2.5.2",
+    "typing_extensions>=4.12",
     "uvicorn>=0.23.1; sys_platform != 'emscripten'",
     "jsonschema>=4.20.0",
 ]
@@ -49,10 +50,10 @@ required-version = ">=0.7.2"
 
 [dependency-groups]
 dev = [
+    "anyio[trio]",
     "pyright>=1.1.391",
     "pytest>=8.3.4",
     "ruff>=0.8.5",
-    "trio>=0.26.2",
     "pytest-flakefinder>=1.1.0",
     "pytest-xdist>=3.6.1",
     "pytest-examples>=0.0.14",
@@ -123,5 +124,5 @@ filterwarnings = [
     # This should be fixed on Uvicorn's side.
     "ignore::DeprecationWarning:websockets",
     "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
-    "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
+    "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
 ]
diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py
index 39ac34d8a..bfb0f0aa1 100644
--- a/src/mcp/client/streamable_http.py
+++ b/src/mcp/client/streamable_http.py
@@ -8,7 +8,7 @@
 
 import logging
 from collections.abc import AsyncGenerator, Awaitable, Callable
-from contextlib import asynccontextmanager
+from contextlib import aclosing, asynccontextmanager
 from dataclasses import dataclass
 from datetime import timedelta
 
@@ -240,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
             event_source.response.raise_for_status()
             logger.debug("Resumption GET SSE connection established")
 
-            async for sse in event_source.aiter_sse():
-                is_complete = await self._handle_sse_event(
-                    sse,
-                    ctx.read_stream_writer,
-                    original_request_id,
-                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
-                )
-                if is_complete:
-                    break
+            async with aclosing(event_source.aiter_sse()) as iterator:
+                async for sse in iterator:
+                    is_complete = await self._handle_sse_event(
+                        sse,
+                        ctx.read_stream_writer,
+                        original_request_id,
+                        ctx.metadata.on_resumption_token_update if ctx.metadata else None,
+                    )
+                    if is_complete:
+                        break
 
     async def _handle_post_request(self, ctx: RequestContext) -> None:
         """Handle a POST request with response processing."""
@@ -319,18 +320,18 @@ async def _handle_sse_response(
     ) -> None:
         """Handle SSE response from the server."""
         try:
-            event_source = EventSource(response)
-            async for sse in event_source.aiter_sse():
-                is_complete = await self._handle_sse_event(
-                    sse,
-                    ctx.read_stream_writer,
-                    resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
-                    is_initialization=is_initialization,
-                )
-                # If the SSE event indicates completion, like returning respose/error
-                # break the loop
-                if is_complete:
-                    break
+            async with aclosing(EventSource(response).aiter_sse()) as items:
+                async for sse in items:
+                    is_complete = await self._handle_sse_event(
+                        sse,
+                        ctx.read_stream_writer,
+                        resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
+                        is_initialization=is_initialization,
+                    )
+                    # If the SSE event indicates completion, like returning respose/error
+                    # break the loop
+                    if is_complete:
+                        break
         except Exception as e:
             logger.exception("Error reading SSE stream:")
             await ctx.read_stream_writer.send(e)
@@ -471,15 +472,14 @@ async def streamablehttp_client(
     read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
     write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
 
-    async with anyio.create_task_group() as tg:
-        try:
-            logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
+    try:
+        logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
 
-            async with httpx_client_factory(
-                headers=transport.request_headers,
-                timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
-                auth=transport.auth,
-            ) as client:
+        async with create_mcp_http_client(
+            headers=transport.request_headers,
+            timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
+        ) as client:
+            async with anyio.create_task_group() as tg:
                 # Define callbacks that need access to tg
                 def start_get_stream() -> None:
                     tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
@@ -504,6 +504,6 @@ def start_get_stream() -> None:
                     if transport.session_id and terminate_on_close:
                         await transport.terminate_session(client)
                     tg.cancel_scope.cancel()
-        finally:
-            await read_stream_writer.aclose()
-            await write_stream.aclose()
+    finally:
+        await read_stream_writer.aclose()
+        await write_stream.aclose()
diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py
index 5c696b136..7e23bd5d0 100644
--- a/src/mcp/server/session.py
+++ b/src/mcp/server/session.py
@@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
 import anyio.lowlevel
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from pydantic import AnyUrl
+from typing_extensions import Self
 
 import mcp.types as types
 from mcp.server.models import InitializationOptions
@@ -93,10 +94,14 @@ def __init__(
         )
         self._init_options = init_options
+
+    async def __aenter__(self) -> Self:
+        await super().__aenter__()
         self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
             ServerRequestResponder
         ](0)
-        self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
+        self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose)
+        return self
 
     @property
     def client_params(self) -> types.InitializeRequestParams | None:
diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py
index 6536272d9..7db5d4c1c 100644
--- a/src/mcp/shared/session.py
+++ b/src/mcp/shared/session.py
@@ -7,6 +7,7 @@
 
 import anyio
 import httpx
+from anyio.abc import TaskGroup
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -177,6 +178,8 @@ class BaseSession(
     _request_id: int
     _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
     _progress_callbacks: dict[RequestId, ProgressFnT]
+    _exit_stack: AsyncExitStack
+    _task_group: TaskGroup
 
     def __init__(
         self,
@@ -196,12 +199,17 @@ def __init__(
         self._session_read_timeout_seconds = read_timeout_seconds
         self._in_flight = {}
         self._progress_callbacks = {}
-        self._exit_stack = AsyncExitStack()
 
     async def __aenter__(self) -> Self:
-        self._task_group = anyio.create_task_group()
-        await self._task_group.__aenter__()
-        self._task_group.start_soon(self._receive_loop)
+        async with AsyncExitStack() as exit_stack:
+            self._task_group = await exit_stack.enter_async_context(anyio.create_task_group())
+            self._task_group.start_soon(self._receive_loop)
+            # Using BaseSession as a context manager should not block on exit (this
+            # would be very surprising behavior), so make sure to cancel the tasks
+            # in the task group.
+            exit_stack.callback(self._task_group.cancel_scope.cancel)
+            self._exit_stack = exit_stack.pop_all()
+
         return self
 
     async def __aexit__(
@@ -210,12 +218,7 @@ async def __aexit__(
         exc_val: BaseException | None,
         exc_tb: TracebackType | None,
     ) -> bool | None:
-        await self._exit_stack.aclose()
-        # Using BaseSession as a context manager should not block on exit (this
-        # would be very surprising behavior), so make sure to cancel the tasks
-        # in the task group.
-        self._task_group.cancel_scope.cancel()
-        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+        return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
 
     async def send_request(
         self,
diff --git a/tests/client/test_session.py b/tests/client/test_session.py
index 327d1a9e4..12c043fe7 100644
--- a/tests/client/test_session.py
+++ b/tests/client/test_session.py
@@ -334,15 +334,15 @@ async def mock_server():
     )
 
     async with (
+        client_to_server_send,
+        client_to_server_receive,
+        server_to_client_send,
+        server_to_client_receive,
         ClientSession(
             server_to_client_receive,
             client_to_server_send,
         ) as session,
         anyio.create_task_group() as tg,
-        client_to_server_send,
-        client_to_server_receive,
-        server_to_client_send,
-        server_to_client_receive,
     ):
         tg.start_soon(mock_server)
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index af7e47993..000000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import pytest
-
-
-@pytest.fixture
-def anyio_backend():
-    return "asyncio"
diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py
index 43af35061..4982e6fe4 100644
--- a/tests/server/test_sse_security.py
+++ b/tests/server/test_sse_security.py
@@ -4,6 +4,7 @@
 import multiprocessing
 import socket
 import time
+from contextlib import contextmanager
 
 import httpx
 import pytest
@@ -66,29 +67,40 @@ async def handle_sse(request: Request):
     uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
 
 
+@contextmanager
 def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
     """Start server in a separate process."""
-    process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
+    context = multiprocessing.get_context("spawn")
+    process = context.Process(target=run_server_with_settings, args=(port, security_settings))
     process.start()
-    # Give server time to start
-    time.sleep(1)
-    return process
+
+    # Wait until the designated port can be connected
+    max_attempts = 20
+    for attempt in range(max_attempts):
+        try:
+            with socket.create_connection(("127.0.0.1", port)):
+                break
+        except ConnectionRefusedError:
+            time.sleep(0.1)
+    else:
+        raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
+
+    try:
+        yield
+    finally:
+        process.terminate()
+        process.join()
 
 
 @pytest.mark.anyio
 async def test_sse_security_default_settings(server_port: int):
     """Test SSE with default security settings (protection disabled)."""
-    process = start_server_process(server_port)
-
-    try:
+    with start_server_process(server_port):
         headers = {"Host": "evil.com", "Origin": "http://evil.com"}
 
         async with httpx.AsyncClient(timeout=5.0) as client:
             async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
                 assert response.status_code == 200
-    finally:
-        process.terminate()
-        process.join()
@@ -96,9 +108,7 @@ async def test_sse_security_invalid_host_header(server_port: int):
     """Test SSE with invalid Host header."""
     # Enable security by providing settings with an empty allowed_hosts list
     security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
-    process = start_server_process(server_port, security_settings)
-
-    try:
+    with start_server_process(server_port, security_settings):
         # Test with invalid host header
         headers = {"Host": "evil.com"}
 
@@ -107,10 +117,6 @@ async def test_sse_security_invalid_host_header(server_port: int):
 
             assert response.status_code == 421
             assert response.text == "Invalid Host header"
-    finally:
-        process.terminate()
-        process.join()
-
 
 @pytest.mark.anyio
 async def test_sse_security_invalid_origin_header(server_port: int):
@@ -119,9 +125,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
     """Test SSE with invalid origin header."""
    # Enable security by providing settings with an empty allowed_origins list
     security_settings = TransportSecuritySettings(
         enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
     )
-    process = start_server_process(server_port, security_settings)
-
-    try:
+    with start_server_process(server_port, security_settings):
         # Test with invalid origin header
         headers = {"Origin": "http://evil.com"}
 
@@ -130,10 +134,6 @@ async def test_sse_security_invalid_origin_header(server_port: int):
 
             assert response.status_code == 400
             assert response.text == "Invalid Origin header"
-    finally:
-        process.terminate()
-        process.join()
-
 
 @pytest.mark.anyio
 async def test_sse_security_post_invalid_content_type(server_port: int):
@@ -142,9 +142,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
     """Test POST requests with invalid content type."""
     security_settings = TransportSecuritySettings(
         enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
     )
-    process = start_server_process(server_port, security_settings)
-
-    try:
+    with start_server_process(server_port, security_settings):
         async with httpx.AsyncClient(timeout=5.0) as client:
             # Test POST with invalid content type
             fake_session_id = "12345678123456781234567812345678"
@@ -163,18 +161,12 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
 
             assert response.status_code == 400
             assert response.text == "Invalid Content-Type header"
-    finally:
-        process.terminate()
-        process.join()
-
 
 @pytest.mark.anyio
 async def test_sse_security_disabled(server_port: int):
     """Test SSE with security disabled."""
     settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
-    process = start_server_process(server_port, settings)
-
-    try:
+    with start_server_process(server_port, settings):
         # Test with invalid host header - should still work
         headers = {"Host": "evil.com"}
 
@@ -184,10 +176,6 @@ async def test_sse_security_disabled(server_port: int):
 
                 # Should connect successfully even with invalid host
                 assert response.status_code == 200
-    finally:
-        process.terminate()
-        process.join()
-
 
 @pytest.mark.anyio
 async def test_sse_security_custom_allowed_hosts(server_port: int):
@@ -197,9 +185,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
         allowed_hosts=["localhost", "127.0.0.1", "custom.host"],
         allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"],
     )
-    process = start_server_process(server_port, settings)
-
-    try:
+    with start_server_process(server_port, settings):
         # Test with custom allowed host
         headers = {"Host": "custom.host"}
 
@@ -217,10 +203,6 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
 
             assert response.status_code == 421
             assert response.text == "Invalid Host header"
-    finally:
-        process.terminate()
-        process.join()
-
 
 @pytest.mark.anyio
 async def test_sse_security_wildcard_ports(server_port: int):
@@ -230,9 +212,7 @@ async def test_sse_security_wildcard_ports(server_port: int):
         allowed_hosts=["localhost:*", "127.0.0.1:*"],
         allowed_origins=["http://localhost:*", "http://127.0.0.1:*"],
     )
-    process = start_server_process(server_port, settings)
-
-    try:
+    with start_server_process(server_port, settings):
         # Test with various port numbers
         for test_port in [8080, 3000, 9999]:
             headers = {"Host": f"localhost:{test_port}"}
@@ -251,10 +231,6 @@ async def test_sse_security_wildcard_ports(server_port: int):
 
                 # Should connect successfully with any port
                 assert response.status_code == 200
-    finally:
-        process.terminate()
-        process.join()
-
 
 @pytest.mark.anyio
 async def test_sse_security_post_valid_content_type(server_port: int):
@@ -263,9 +239,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
     """Test POST requests with valid content types."""
     security_settings = TransportSecuritySettings(
         enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
     )
-    process = start_server_process(server_port, security_settings)
-
-    try:
+    with start_server_process(server_port, security_settings):
         async with httpx.AsyncClient() as client:
             # Test with various valid content types
             valid_content_types = [
@@ -287,7 +261,3 @@ async def test_sse_security_post_valid_content_type(server_port: int):
             # We're testing that it passes the content-type check
             assert response.status_code == 404
             assert response.text == "Could not find session"
-
-    finally:
-        process.terminate()
-        process.join()
diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py
index eed791924..69e08efe1 100644
--- a/tests/server/test_streamable_http_security.py
+++ b/tests/server/test_streamable_http_security.py
@@ -75,7 +75,8 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
 
 def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
     """Start server in a separate process."""
-    process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
+    context = multiprocessing.get_context("spawn")
+    process = context.Process(target=run_server_with_settings, args=(port, security_settings))
     process.start()
     # Give server time to start
     time.sleep(1)
diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py
index 8e1912e9b..0d5640da8 100644
--- a/tests/shared/test_sse.py
+++ b/tests/shared/test_sse.py
@@ -3,6 +3,7 @@
 import socket
 import time
 from collections.abc import AsyncGenerator, Generator
+from contextlib import aclosing
 
 import anyio
 import httpx
@@ -165,14 +166,15 @@ async def connection_test() -> None:
             assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
 
             line_number = 0
-            async for line in response.aiter_lines():
-                if line_number == 0:
-                    assert line == "event: endpoint"
-                elif line_number == 1:
-                    assert line.startswith("data: /messages/?session_id=")
-                else:
-                    return
-                line_number += 1
+            async with aclosing(response.aiter_lines()) as lines:
+                async for line in lines:
+                    if line_number == 0:
+                        assert line == "event: endpoint"
+                    elif line_number == 1:
+                        assert line.startswith("data: /messages/?session_id=")
+                    else:
+                        return
+                    line_number += 1
 
     # Add timeout to prevent test from hanging if it fails
     with anyio.fail_after(3):
diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py
index 1ffcc13b0..c786283ff 100644
--- a/tests/shared/test_streamable_http.py
+++ b/tests/shared/test_streamable_http.py
@@ -100,16 +100,17 @@ async def replay_events_after(
         """Replay events after the specified ID."""
         # Find the index of the last event ID
         start_index = None
-        for i, (_, event_id, _) in enumerate(self._events):
+        stream_id = None
+        for i, (stream_id_, event_id, _) in enumerate(self._events):
             if event_id == last_event_id:
                 start_index = i + 1
+                stream_id = stream_id_
                 break
 
         if start_index is None:
             # If event ID not found, start from beginning
             start_index = 0
-            stream_id = None
 
         # Replay events
         for _, event_id, message in self._events[start_index:]:
             await send_callback(EventMessage(message, event_id))
@@ -1055,7 +1056,8 @@ async def test_streamablehttp_client_resumption(event_server):
     captured_session_id = None
     captured_resumption_token = None
     captured_notifications = []
-    tool_started = False
+    tool_started_event = anyio.Event()
+    session_resumption_token_received_event = anyio.Event()
     captured_protocol_version = None
 
     async def message_handler(
@@ -1066,12 +1068,12 @@ async def message_handler(
         # Look for our special notification that indicates the tool is running
         if isinstance(message.root, types.LoggingMessageNotification):
             if message.root.params.data == "Tool started":
-                nonlocal tool_started
-                tool_started = True
+                tool_started_event.set()
 
     async def on_resumption_token_update(token: str) -> None:
         nonlocal captured_resumption_token
         captured_resumption_token = token
+        session_resumption_token_received_event.set()
 
     # First, start the client session and begin the long-running tool
     async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
@@ -1110,8 +1112,8 @@ async def run_tool():
 
             # Wait for the tool to start and at least one notification
             # and then kill the task group
-            while not tool_started or not captured_resumption_token:
-                await anyio.sleep(0.1)
+            await tool_started_event.wait()
+            await session_resumption_token_received_event.wait()
             tg.cancel_scope.cancel()
 
         # Store pre notifications and clear the captured notifications
diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py
index 5081f1d53..084043236 100644
--- a/tests/shared/test_ws.py
+++ b/tests/shared/test_ws.py
@@ -27,6 +27,8 @@
 
 SERVER_NAME = "test_server_for_WS"
 
+pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"])
+
 
 @pytest.fixture
 def server_port() -> int:
diff --git a/uv.lock b/uv.lock
index 0e11150fa..0d28dfbf7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -40,6 +40,11 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250, upload-time = "2024-09-19T09:28:42.699Z" },
 ]
 
+[package.optional-dependencies]
+trio = [
+    { name = "trio" },
+]
+
 [[package]]
 name = "asttokens"
 version = "2.4.1"
@@ -393,11 +398,11 @@ wheels = [
 
 [[package]]
 name = "httpx-sse"
-version = "0.4.0"
+version = "0.4.1"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624, upload-time = "2023-12-22T08:01:21.083Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/6e/fa/66bd985dd0b7c109a3bcb89272ee0bfb7e2b4d06309ad7b38ff866734b2a/httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e", size = 12998, upload-time = "2025-06-24T13:21:05.71Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819, upload-time = "2023-12-22T08:01:19.89Z" },
+    { url = "https://files.pythonhosted.org/packages/25/0a/6269e3473b09aed2dab8aa1a600c70f31f00ae1349bee30658f7e358a159/httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37", size = 8054, upload-time = "2025-06-24T13:21:04.772Z" },
 ]
 
 [[package]]
@@ -565,6 +570,7 @@ dependencies = [
     { name = "python-multipart" },
     { name = "sse-starlette" },
     { name = "starlette" },
+    { name = "typing-extensions" },
     { name = "uvicorn", marker = "sys_platform != 'emscripten'" },
 ]
 
@@ -582,6 +588,7 @@ ws = [
 
 [package.dev-dependencies]
 dev = [
+    { name = "anyio", extra = ["trio"] },
     { name = "inline-snapshot" },
     { name = "pyright" },
     { name = "pytest" },
@@ -590,7 +597,6 @@ dev = [
     { name = "pytest-pretty" },
     { name = "pytest-xdist" },
     { name = "ruff" },
-    { name = "trio" },
 ]
 docs = [
     { name = "mkdocs" },
@@ -603,7 +609,7 @@ docs = [
 requires-dist = [
     { name = "anyio", specifier = ">=4.5" },
     { name = "httpx", specifier = ">=0.27" },
-    { name = "httpx-sse", specifier = ">=0.4" },
+    { name = "httpx-sse", specifier = ">=0.4.1" },
     { name = "jsonschema", specifier = ">=4.20.0" },
     { name = "pydantic", specifier = ">=2.7.2,<3.0.0" },
     { name = "pydantic-settings", specifier = ">=2.5.2" },
@@ -613,6 +619,7 @@ requires-dist = [
     { name = "sse-starlette", specifier = ">=1.6.1" },
     { name = "starlette", specifier = ">=0.27" },
     { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" },
+    { name = "typing-extensions", specifier = ">=4.12" },
     { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" },
     { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" },
 ]
@@ -620,6 +627,7 @@ provides-extras = ["cli", "rich", "ws"]
 
 [package.metadata.requires-dev]
 dev = [
+    { name = "anyio", extras = ["trio"] },
    { name = "inline-snapshot", specifier = ">=0.23.0" },
     { name = "pyright", specifier = ">=1.1.391" },
     { name = "pytest", specifier = ">=8.3.4" },
@@ -628,7 +636,6 @@ dev = [
     { name = "pytest-pretty", specifier = ">=1.2.0" },
     { name = "pytest-xdist", specifier = ">=3.6.1" },
     { name = "ruff", specifier = ">=0.8.5" },
-    { name = "trio", specifier = ">=0.26.2" },
 ]
 docs = [
     { name = "mkdocs", specifier = ">=1.6.1" },