Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
"""

from enum import Enum
from typing import Any, TypeVar, overload
from typing import Any, TypeVar, get_args, overload

import anyio
import anyio.lowlevel
Expand Down Expand Up @@ -63,6 +63,12 @@ class InitializationState(Enum):
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
)

_KNOWN_CLIENT_REQUEST_METHODS = frozenset(
method
for request_type in get_args(types.ClientRequest)
if isinstance(method := request_type.model_fields["method"].default, str)
)


class ServerSession(
BaseSession[
Expand Down Expand Up @@ -104,6 +110,11 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]:
return types.client_notification_adapter

def _get_request_validation_error(self, request: types.JSONRPCRequest) -> types.ErrorData:
if request.method not in _KNOWN_CLIENT_REQUEST_METHODS:
return types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found")
return super()._get_request_validation_error(request)

@property
def client_params(self) -> types.InitializeRequestParams | None:
return self._client_params
Expand Down
5 changes: 4 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]:
def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

def _get_request_validation_error(self, request: JSONRPCRequest) -> ErrorData:
return ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
try:
Expand Down Expand Up @@ -363,7 +366,7 @@ async def _receive_loop(self) -> None:
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.id,
error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""),
error=self._get_request_validation_error(message.message),
)
session_message = SessionMessage(message=error_response)
await self._write_stream.send(session_message)
Expand Down
54 changes: 54 additions & 0 deletions tests/issues/test_1561_invalid_method_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import anyio
import pytest

from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.types import INVALID_PARAMS, METHOD_NOT_FOUND, JSONRPCError, JSONRPCRequest, ServerCapabilities


@pytest.mark.anyio
async def test_invalid_method_names_return_method_not_found() -> None:
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10)

try:
async with ServerSession(
read_stream=read_receive_stream,
write_stream=write_send_stream,
init_options=InitializationOptions(
server_name="test_server",
server_version="1.0.0",
capabilities=ServerCapabilities(),
),
):
await read_send_stream.send(
SessionMessage(
message=JSONRPCRequest(jsonrpc="2.0", id=1, method="invalid/method", params={})
)
)

invalid_method_response = (await write_receive_stream.receive()).message

assert isinstance(invalid_method_response, JSONRPCError)
assert invalid_method_response.id == 1
assert invalid_method_response.error.code == METHOD_NOT_FOUND
assert invalid_method_response.error.message == "Method not found"

await read_send_stream.send(
SessionMessage(
message=JSONRPCRequest(jsonrpc="2.0", id=2, method="initialize")
)
)

malformed_known_method_response = (await write_receive_stream.receive()).message

assert isinstance(malformed_known_method_response, JSONRPCError)
assert malformed_known_method_response.id == 2
assert malformed_known_method_response.error.code == INVALID_PARAMS
assert malformed_known_method_response.error.message == "Invalid request parameters"
finally: # pragma: lax no cover
await read_send_stream.aclose()
await write_send_stream.aclose()
await read_receive_stream.aclose()
await write_receive_stream.aclose()
Loading