Skip to content

Improved Trio support #946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ classifiers = [
dependencies = [
"anyio>=4.5",
"httpx>=0.27",
"httpx-sse>=0.4",
"httpx-sse>=0.4.1",
"pydantic>=2.7.2,<3.0.0",
"starlette>=0.27",
"python-multipart>=0.0.9",
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"typing_extensions>=4.12",
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
"jsonschema>=4.20.0",
]
Expand All @@ -49,10 +50,10 @@ required-version = ">=0.7.2"

[dependency-groups]
dev = [
"anyio[trio]",
"pyright>=1.1.391",
"pytest>=8.3.4",
"ruff>=0.8.5",
"trio>=0.26.2",
"pytest-flakefinder>=1.1.0",
"pytest-xdist>=3.6.1",
"pytest-examples>=0.0.14",
Expand Down Expand Up @@ -123,5 +124,5 @@ filterwarnings = [
# This should be fixed on Uvicorn's side.
"ignore::DeprecationWarning:websockets",
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
]
66 changes: 33 additions & 33 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from contextlib import aclosing, asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta

Expand Down Expand Up @@ -240,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")

async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
async with aclosing(event_source.aiter_sse()) as iterator:
async for sse in iterator:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
Expand Down Expand Up @@ -319,18 +320,18 @@ async def _handle_sse_response(
) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
async with aclosing(EventSource(response).aiter_sse()) as items:
async for sse in items:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
Expand Down Expand Up @@ -471,15 +472,14 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
async with create_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
async with anyio.create_task_group() as tg:
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
Expand All @@ -504,6 +504,6 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import Self

import mcp.types as types
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -93,10 +94,14 @@ def __init__(
)

self._init_options = init_options

async def __aenter__(self) -> Self:
await super().__aenter__()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
ServerRequestResponder
](0)
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose)
return self

@property
def client_params(self) -> types.InitializeRequestParams | None:
Expand Down
23 changes: 13 additions & 10 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from typing_extensions import Self
Expand Down Expand Up @@ -177,6 +178,8 @@ class BaseSession(
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_exit_stack: AsyncExitStack
_task_group: TaskGroup

def __init__(
self,
Expand All @@ -196,12 +199,17 @@ def __init__(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(anyio.create_task_group())
self._task_group.start_soon(self._receive_loop)
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
exit_stack.callback(self._task_group.cancel_scope.cancel)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(
Expand All @@ -210,12 +218,7 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)

async def send_request(
self,
Expand Down
8 changes: 4 additions & 4 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,15 @@ async def mock_server():
)

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
ClientSession(
server_to_client_receive,
client_to_server_send,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)

Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py

This file was deleted.

Loading
Loading