Skip to content
113 changes: 113 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@
from typing import Optional
from collections.abc import Iterator

try:
from anyio import create_memory_object_stream, create_task_group
from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
)
from mcp.shared.message import SessionMessage
except ImportError:
create_memory_object_stream = None
create_task_group = None
JSONRPCMessage = None
JSONRPCNotification = None
JSONRPCRequest = None
SessionMessage = None


SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"

Expand Down Expand Up @@ -592,6 +608,103 @@ def suppress_deprecation_warnings():
yield


@pytest.fixture
def get_initialization_payload():
def inner(request_id: str):
return SessionMessage( # type: ignore
message=JSONRPCMessage( # type: ignore
root=JSONRPCRequest( # type: ignore
jsonrpc="2.0",
id=request_id,
method="initialize",
params={
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
)
)
)

return inner


@pytest.fixture
def get_initialized_notification_payload():
def inner():
return SessionMessage( # type: ignore
message=JSONRPCMessage( # type: ignore
root=JSONRPCNotification( # type: ignore
jsonrpc="2.0",
method="notifications/initialized",
)
)
)

return inner


@pytest.fixture
def get_mcp_command_payload():
def inner(method: str, params, request_id: str):
return SessionMessage( # type: ignore
message=JSONRPCMessage( # type: ignore
root=JSONRPCRequest( # type: ignore
jsonrpc="2.0",
id=request_id,
method=method,
params=params,
)
)
)

return inner


@pytest.fixture
def stdio(
get_initialization_payload,
get_initialized_notification_payload,
get_mcp_command_payload,
):
async def inner(server, method: str, params, request_id: str):
read_stream_writer, read_stream = create_memory_object_stream(0) # type: ignore
write_stream, write_stream_reader = create_memory_object_stream(0) # type: ignore

result = {}

async def run_server():
await server.run(
read_stream, write_stream, server.create_initialization_options()
)

async def simulate_client(tg, result):
init_request = get_initialization_payload("1")
await read_stream_writer.send(init_request)

await write_stream_reader.receive()

initialized_notification = get_initialized_notification_payload()
await read_stream_writer.send(initialized_notification)

request = get_mcp_command_payload(
method, params=params, request_id=request_id
)
await read_stream_writer.send(request)

result["response"] = await write_stream_reader.receive()

tg.cancel_scope.cancel()

async with create_task_group() as tg: # type: ignore
tg.start_soon(run_server)
tg.start_soon(simulate_client, tg, result)

return result["response"]

return inner


class MockServerRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: N802
# Process an HTTP GET request and return a response.
Expand Down
Loading
Loading