From 9327cff3f6823fc5980e4ef63f9867bcf722c074 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 12 Jun 2025 13:30:39 +0300 Subject: [PATCH 1/7] Improved Trio support * Removed the asyncio-only parametrization of the anyio_backend except for test_ws, as `websockets` doesn't support Trio yet * Try to close async generators explicitly where possible * Changed nesting order for more predictable closing of async resources * Refactored `__aenter__` and `__aexit__` in some cases to exit the task group if there's a problem during initialization * Fixed test failures in client/test_auth.py where an async fixture was used in sync tests * Fixed subtle bug in `SimpleEventStore` where retrieving the stream ID was timing-dependent --- pyproject.toml | 8 +++-- src/mcp/client/streamable_http.py | 48 +++++++++++++++------------- src/mcp/server/session.py | 15 ++++++--- src/mcp/shared/session.py | 25 +++++++++------ tests/client/test_auth.py | 2 +- tests/client/test_session.py | 8 ++--- tests/conftest.py | 6 ---- tests/shared/test_streamable_http.py | 16 ++++++---- tests/shared/test_ws.py | 2 ++ uv.lock | 11 +++++-- 10 files changed, 83 insertions(+), 58 deletions(-) delete mode 100644 tests/conftest.py diff --git a/pyproject.toml b/pyproject.toml index 9ad50ab58..6e3c00d90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "python-multipart>=0.0.9", "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", + "typing_extensions>=4.12", "uvicorn>=0.23.1; sys_platform != 'emscripten'", ] @@ -48,10 +49,10 @@ required-version = ">=0.7.2" [dependency-groups] dev = [ + "anyio[trio]", "pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5", - "trio>=0.26.2", "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", @@ -122,5 +123,8 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", + # This is to avoid test failures on Trio due to httpx's failure to explicitly close + # async generators + "ignore::pytest.PytestUnraisableExceptionWarning" ] diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 471870533..15fc3393a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -8,9 +8,10 @@ import logging from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager +from contextlib import aclosing, asynccontextmanager from dataclasses import dataclass from datetime import timedelta +from typing import cast import anyio import httpx @@ -284,16 +285,18 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte """Handle SSE response from the server.""" try: event_source = EventSource(response) - async for sse in event_source.aiter_sse(): - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), - ) - # If the SSE event indicates completion, like returning respose/error - # break the loop - if is_complete: - break + sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse()) + async with aclosing(sse_iter) as items: + async for sse in items: + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), + ) + # If the SSE event indicates completion, like returning respose/error + # break the loop + if is_complete: + break except Exception as e: logger.exception("Error reading SSE stream:") await ctx.read_stream_writer.send(e) @@ -434,15 +437,16 @@ async def streamablehttp_client( read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + try: + logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with httpx_client_factory( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - auth=transport.auth, - ) as client: + async with create_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout( + transport.timeout, read=transport.sse_read_timeout + ), + ) as client: + async with anyio.create_task_group() as tg: # Define callbacks that need access to tg def start_get_stream() -> None: tg.start_soon(transport.handle_get_stream, client, read_stream_writer) @@ -467,6 +471,6 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e6611b0d4..df1dd93e9 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +from typing_extensions import Self import mcp.types as types from mcp.server.models import InitializationOptions @@ -93,10 +94,16 @@ def __init__( ) self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ - ServerRequestResponder - ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + + async def __aenter__(self) -> Self: + await super().__aenter__() + self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( + anyio.create_memory_object_stream[ServerRequestResponder](0) + ) + self._exit_stack.push_async_callback( + self._incoming_message_stream_reader.aclose + ) + return self @property def client_params(self) -> types.InitializeRequestParams | None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 294986acb..a8c9b29d7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -7,6 +7,7 @@ import anyio import httpx +from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel from typing_extensions import Self @@ -177,6 +178,8 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _exit_stack: AsyncExitStack + _task_group: TaskGroup def __init__( self, @@ -196,12 +199,19 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - self._task_group.start_soon(self._receive_loop) + async with AsyncExitStack() as exit_stack: + self._task_group = await exit_stack.enter_async_context( + anyio.create_task_group() + ) + self._task_group.start_soon(self._receive_loop) + # Using BaseSession as a context manager should not block on exit (this + # would be very surprising behavior), so make sure to cancel the tasks + # in the task group. + exit_stack.callback(self._task_group.cancel_scope.cancel) + self._exit_stack = exit_stack.pop_all() + return self async def __aexit__( @@ -210,12 +220,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - await self._exit_stack.aclose() - # Using BaseSession as a context manager should not block on exit (this - # would be very surprising behavior), so make sure to cancel the tasks - # in the task group. - self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) async def send_request( self, diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index de4eb70af..e514bb5f7 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -99,7 +99,7 @@ def oauth_token(): @pytest.fixture -async def oauth_provider(client_metadata, mock_storage): +def oauth_provider(client_metadata, mock_storage): async def mock_redirect_handler(url: str) -> None: pass diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 327d1a9e4..12c043fe7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -334,15 +334,15 @@ async def mock_server(): ) async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, ClientSession( server_to_client_receive, client_to_server_send, ) as session, anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, ): tg.start_soon(mock_server) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index af7e47993..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -import pytest - - -@pytest.fixture -def anyio_backend(): - return "asyncio" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 615e68efc..d61538c2a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -87,16 +87,17 @@ async def replay_events_after( """Replay events after the specified ID.""" # Find the index of the last event ID start_index = None - for i, (_, event_id, _) in enumerate(self._events): + stream_id = None + for i, (stream_id_, event_id, _) in enumerate(self._events): if event_id == last_event_id: start_index = i + 1 + stream_id = stream_id_ break if start_index is None: # If event ID not found, start from beginning start_index = 0 - stream_id = None # Replay events for _, event_id, message in self._events[start_index:]: await send_callback(EventMessage(message, event_id)) @@ -1003,7 +1004,8 @@ async def test_streamablehttp_client_resumption(event_server): captured_session_id = None captured_resumption_token = None captured_notifications = [] - tool_started = False + tool_started_event = anyio.Event() + session_resumption_token_received_event = anyio.Event() async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1013,12 +1015,12 @@ async def message_handler( # Look for our special notification that indicates the tool is running if isinstance(message.root, types.LoggingMessageNotification): if message.root.params.data == "Tool started": - nonlocal tool_started - tool_started = True + tool_started_event.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + session_resumption_token_received_event.set() # First, start the client session and begin the long-running tool async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( @@ -1055,8 +1057,8 @@ async def run_tool(): # Wait for the tool to start and at least one notification # and then kill the task group - while not tool_started or not captured_resumption_token: - await anyio.sleep(0.1) + await tool_started_event.wait() + await session_resumption_token_received_event.wait() tg.cancel_scope.cancel() # Store pre notifications and clear the captured notifications diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 5081f1d53..084043236 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -27,6 +27,8 @@ SERVER_NAME = "test_server_for_WS" +pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.fixture def server_port() -> int: diff --git a/uv.lock b/uv.lock index 180d5a9c1..6ae56f94b 100644 --- a/uv.lock +++ b/uv.lock @@ -40,6 +40,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250, upload-time = "2024-09-19T09:28:42.699Z" }, ] +[package.optional-dependencies] +trio = [ + { name = "trio" }, +] + [[package]] name = "asttokens" version = "2.4.1" @@ -537,6 +542,7 @@ dependencies = [ { name = "python-multipart" }, { name = "sse-starlette" }, { name = "starlette" }, + { name = "typing-extensions" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ] @@ -554,6 +560,7 @@ ws = [ [package.dev-dependencies] dev = [ + { name = "anyio", extra = ["trio"] }, { name = "inline-snapshot" }, { name = "pyright" }, { name = "pytest" }, @@ -562,7 +569,6 @@ dev = [ { name = "pytest-pretty" }, { name = "pytest-xdist" }, { name = "ruff" }, - { name = "trio" }, ] docs = [ { name = "mkdocs" }, @@ -584,6 +590,7 @@ requires-dist = [ { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, + { name = "typing-extensions", specifier = ">=4.12" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] @@ -591,6 +598,7 @@ provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ + { name = "anyio", extras = ["trio"] }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=8.3.4" }, @@ -599,7 +607,6 @@ dev = [ { name = "pytest-pretty", specifier = ">=1.2.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, - { name = "trio", specifier = ">=0.26.2" }, ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, From 0ce5b79dc54377f30741e480f08df54faaefc161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 12 Jun 2025 14:08:08 +0300 Subject: [PATCH 2/7] Fixed pre-commit errors --- src/mcp/client/streamable_http.py | 4 +--- src/mcp/server/session.py | 10 ++++------ src/mcp/shared/session.py | 4 +--- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 15fc3393a..1b025e83e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -442,9 +442,7 @@ async def streamablehttp_client( async with create_mcp_http_client( headers=transport.request_headers, - timeout=httpx.Timeout( - transport.timeout, read=transport.sse_read_timeout - ), + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), ) as client: async with anyio.create_task_group() as tg: # Define callbacks that need access to tg diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index df1dd93e9..61d654744 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -97,12 +97,10 @@ def __init__( async def __aenter__(self) -> Self: await super().__aenter__() - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ServerRequestResponder](0) - ) - self._exit_stack.push_async_callback( - self._incoming_message_stream_reader.aclose - ) + self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ + ServerRequestResponder + ](0) + self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose) return self @property diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index a8c9b29d7..2ff29304a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -202,9 +202,7 @@ def __init__( async def __aenter__(self) -> Self: async with AsyncExitStack() as exit_stack: - self._task_group = await exit_stack.enter_async_context( - anyio.create_task_group() - ) + self._task_group = await exit_stack.enter_async_context(anyio.create_task_group()) self._task_group.start_soon(self._receive_loop) # Using BaseSession as a context manager should not block on exit (this # would be very surprising behavior), so make sure to cancel the tasks From 83d23ad0e324c31911ed86a14d41e2290bff60e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 18 Jun 2025 16:56:20 +0300 Subject: [PATCH 3/7] Fixed uses of async generators and removed the pytest warning ignore --- pyproject.toml | 1 - src/mcp/client/streamable_http.py | 24 +++++++++++------------- tests/client/test_auth.py | 31 ++++++++++++++++--------------- tests/shared/test_sse.py | 18 ++++++++++-------- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e3c00d90..f20a81946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,5 +126,4 @@ filterwarnings = [ "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", # This is to avoid test failures on Trio due to httpx's failure to explicitly close # async generators - "ignore::pytest.PytestUnraisableExceptionWarning" ] diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 970c3c682..bfb0f0aa1 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,7 +11,6 @@ from contextlib import aclosing, asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import cast import anyio import httpx @@ -241,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") - async for sse in event_source.aiter_sse(): - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - original_request_id, - ctx.metadata.on_resumption_token_update if ctx.metadata else None, - ) - if is_complete: - break + async with aclosing(event_source.aiter_sse()) as iterator: + async for sse in iterator: + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" @@ -320,9 +320,7 @@ async def _handle_sse_response( ) -> None: """Handle SSE response from the server.""" try: - event_source = EventSource(response) - sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse()) - async with aclosing(sse_iter) as items: + async with aclosing(EventSource(response).aiter_sse()) as items: async for sse in items: is_complete = await self._handle_sse_event( sse, diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index e514bb5f7..ef202facd 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -5,6 +5,7 @@ import base64 import hashlib import time +from contextlib import aclosing from unittest.mock import AsyncMock, Mock, patch from urllib.parse import parse_qs, urlparse @@ -654,17 +655,17 @@ async def test_async_auth_flow_401_response(self, oauth_provider, oauth_token): mock_response = Mock() mock_response.status_code = 401 - auth_flow = oauth_provider.async_auth_flow(request) - await auth_flow.__anext__() + async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow: + await auth_flow.__anext__() - # Send 401 response - try: - await auth_flow.asend(mock_response) - except StopAsyncIteration: - pass + # Send 401 response + try: + await auth_flow.asend(mock_response) + except StopAsyncIteration: + pass - # Should clear current tokens - assert oauth_provider._current_tokens is None + # Should clear current tokens + assert oauth_provider._current_tokens is None @pytest.mark.anyio async def test_async_auth_flow_no_token(self, oauth_provider): @@ -675,14 +676,14 @@ async def test_async_auth_flow_no_token(self, oauth_provider): patch.object(oauth_provider, "initialize") as mock_init, patch.object(oauth_provider, "ensure_token") as mock_ensure, ): - auth_flow = oauth_provider.async_auth_flow(request) - updated_request = await auth_flow.__anext__() + async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow: + updated_request = await auth_flow.__anext__() - mock_init.assert_called_once() - mock_ensure.assert_called_once() + mock_init.assert_called_once() + mock_ensure.assert_called_once() - # No Authorization header should be added if no token - assert "Authorization" not in updated_request.headers + # No Authorization header should be added if no token + assert "Authorization" not in updated_request.headers @pytest.mark.anyio async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info): diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4d8f7717e..43bb3320f 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,6 +3,7 @@ import socket import time from collections.abc import AsyncGenerator, Generator +from contextlib import aclosing import anyio import httpx @@ -160,14 +161,15 @@ async def connection_test() -> None: assert response.headers["content-type"] == "text/event-stream; charset=utf-8" line_number = 0 - async for line in response.aiter_lines(): - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 + async with aclosing(response.aiter_lines()) as lines: + async for line in lines: + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 # Add timeout to prevent test from hanging if it fails with anyio.fail_after(3): From 85c50f2208a90c6e0fb782fdbdf61dec793992fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 23 Jun 2025 13:39:53 +0300 Subject: [PATCH 4/7] Use spawn instead of fork to prevent sniffio detecting the wrong async library on the test server --- tests/server/test_sse_security.py | 3 ++- tests/server/test_streamable_http_security.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 43af35061..957db286b 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -68,7 +68,8 @@ async def handle_sse(request: Request): def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) + context = multiprocessing.get_context("spawn") + process = context.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() # Give server time to start time.sleep(1) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index eed791924..69e08efe1 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -75,7 +75,8 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) + context = multiprocessing.get_context("spawn") + process = context.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() # Give server time to start time.sleep(1) From 95f05186c13814344a4a126493be63901e741ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 25 Jun 2025 22:26:00 +0300 Subject: [PATCH 5/7] Bumped httpx-sse to v0.4.1 --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f20a81946..67d00b608 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ dependencies = [ "anyio>=4.5", "httpx>=0.27", - "httpx-sse>=0.4", + "httpx-sse>=0.4.1", "pydantic>=2.7.2,<3.0.0", "starlette>=0.27", "python-multipart>=0.0.9", diff --git a/uv.lock b/uv.lock index 6ae56f94b..a33ea39b7 100644 --- a/uv.lock +++ b/uv.lock @@ -398,11 +398,11 @@ wheels = [ [[package]] name = "httpx-sse" -version = "0.4.0" +version = "0.4.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624, upload-time = "2023-12-22T08:01:21.083Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/fa/66bd985dd0b7c109a3bcb89272ee0bfb7e2b4d06309ad7b38ff866734b2a/httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e", size = 12998, upload-time = "2025-06-24T13:21:05.71Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819, upload-time = "2023-12-22T08:01:19.89Z" }, + { url = "https://files.pythonhosted.org/packages/25/0a/6269e3473b09aed2dab8aa1a600c70f31f00ae1349bee30658f7e358a159/httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37", size = 8054, upload-time = "2025-06-24T13:21:04.772Z" }, ] [[package]] @@ -581,7 +581,7 @@ docs = [ requires-dist = [ { name = "anyio", specifier = ">=4.5" }, { name = "httpx", specifier = ">=0.27" }, - { name = "httpx-sse", specifier = ">=0.4" }, + { name = "httpx-sse", specifier = ">=0.4.1" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, From 7759fbe7f6d9508d65c8354f0640fb8692cf2ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 26 Jun 2025 17:10:01 +0300 Subject: [PATCH 6/7] Fixed flaky server-side SSE tests --- tests/server/test_sse_security.py | 85 ++++++++++--------------------- 1 file changed, 27 insertions(+), 58 deletions(-) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 957db286b..4982e6fe4 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -4,6 +4,7 @@ import multiprocessing import socket import time +from contextlib import contextmanager import httpx import pytest @@ -66,30 +67,40 @@ async def handle_sse(request: Request): uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") +@contextmanager def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" context = multiprocessing.get_context("spawn") process = context.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() - # Give server time to start - time.sleep(1) - return process + + # Wait until the designated port can be connected + max_attempts = 20 + for attempt in range(max_attempts): + try: + with socket.create_connection(("127.0.0.1", port)): + break + except ConnectionRefusedError: + time.sleep(0.1) + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + try: + yield + finally: + process.terminate() + process.join() @pytest.mark.anyio async def test_sse_security_default_settings(server_port: int): """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) - - try: + with start_server_process(server_port): headers = {"Host": "evil.com", "Origin": "http://evil.com"} async with httpx.AsyncClient(timeout=5.0) as client: async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: assert response.status_code == 200 - finally: - process.terminate() - process.join() @pytest.mark.anyio @@ -97,9 +108,7 @@ async def test_sse_security_invalid_host_header(server_port: int): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - - try: + with start_server_process(server_port, security_settings): # Test with invalid host header headers = {"Host": "evil.com"} @@ -108,10 +117,6 @@ async def test_sse_security_invalid_host_header(server_port: int): assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio async def test_sse_security_invalid_origin_header(server_port: int): @@ -120,9 +125,7 @@ async def test_sse_security_invalid_origin_header(server_port: int): security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with start_server_process(server_port, security_settings): # Test with invalid origin header headers = {"Origin": "http://evil.com"} @@ -131,10 +134,6 @@ async def test_sse_security_invalid_origin_header(server_port: int): assert response.status_code == 400 assert response.text == "Invalid Origin header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio async def test_sse_security_post_invalid_content_type(server_port: int): @@ -143,9 +142,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int): security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with start_server_process(server_port, security_settings): async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type fake_session_id = "12345678123456781234567812345678" @@ -164,18 +161,12 @@ async def test_sse_security_post_invalid_content_type(server_port: int): assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio async def test_sse_security_disabled(server_port: int): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: + with start_server_process(server_port, settings): # Test with invalid host header - should still work headers = {"Host": "evil.com"} @@ -185,10 +176,6 @@ async def test_sse_security_disabled(server_port: int): # Should connect successfully even with invalid host assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio async def test_sse_security_custom_allowed_hosts(server_port: int): @@ -198,9 +185,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: + with start_server_process(server_port, settings): # Test with custom allowed host headers = {"Host": "custom.host"} @@ -218,10 +203,6 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio async def test_sse_security_wildcard_ports(server_port: int): @@ -231,9 +212,7 @@ async def test_sse_security_wildcard_ports(server_port: int): allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - - try: + with start_server_process(server_port, settings): # Test with various port numbers for test_port in [8080, 3000, 9999]: headers = {"Host": f"localhost:{test_port}"} @@ -252,10 +231,6 @@ async def test_sse_security_wildcard_ports(server_port: int): # Should connect successfully with any port assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio async def test_sse_security_post_valid_content_type(server_port: int): @@ -264,9 +239,7 @@ async def test_sse_security_post_valid_content_type(server_port: int): security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with start_server_process(server_port, security_settings): async with httpx.AsyncClient() as client: # Test with various valid content types valid_content_types = [ @@ -288,7 +261,3 @@ async def test_sse_security_post_valid_content_type(server_port: int): # We're testing that it passes the content-type check assert response.status_code == 404 assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() From 62aeec995d28e7f749c86b757f0fbce150472a96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 26 Jun 2025 19:05:04 +0300 Subject: [PATCH 7/7] Removed obsolete comment --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2851be20b..ecbf4c6d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,4 @@ filterwarnings = [ "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", - # This is to avoid test failures on Trio due to httpx's failure to explicitly close - # async generators ]