Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ This document contains critical information about working with this codebase. Fo
- Bug fixes require regression tests
- IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns.
- IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible.
- IMPORTANT: Do NOT test private functions (prefixed with `_`). Test them indirectly through the public API.

Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py`
Add tests to the existing file for that module.
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"pyjwt[crypto]>=2.10.1",
"typing-extensions>=4.13.0",
"typing-inspection>=0.4.1",
"opentelemetry-api>=1.28.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -71,6 +72,7 @@ dev = [
"coverage[toml]>=7.10.7,<=7.13",
"pillow>=12.0",
"strict-no-cover",
"opentelemetry-sdk>=1.28.0",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
30 changes: 29 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from opentelemetry import trace
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

from mcp.shared.exceptions import MCPError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.response_router import ResponseRouter
from mcp.shared.tracing import end_span_error, end_span_ok, start_client_span, start_server_span
from mcp.types import (
CONNECTION_CLOSED,
INVALID_PARAMS,
Expand Down Expand Up @@ -77,6 +79,7 @@ def __init__(
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
message_metadata: MessageMetadata = None,
span: trace.Span | None = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
Expand All @@ -87,6 +90,7 @@ def __init__(
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
self._span = span

def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]:
"""Enter the context manager, enabling request cancellation tracking."""
Expand Down Expand Up @@ -126,6 +130,12 @@ async def respond(self, response: SendResultT | ErrorData) -> None:
if not self.cancelled: # pragma: no branch
self._completed = True

if self._span is not None:
if isinstance(response, ErrorData):
end_span_error(self._span, MCPError(code=response.code, message=response.message))
else:
end_span_ok(self._span)

await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id, response=response
)
Expand All @@ -139,6 +149,10 @@ async def cancel(self) -> None:

self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight

if self._span is not None:
end_span_error(self._span, MCPError(code=0, message="Request cancelled"))

# Send an error response to indicate cancellation
await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id,
Expand Down Expand Up @@ -260,6 +274,9 @@ async def send_request(
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback

method = request_data["method"]
span = start_client_span(method, request_data.get("params"))

try:
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
Expand All @@ -278,7 +295,15 @@ async def send_request(
if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
else:
return result_type.model_validate(response_or_error.result, by_name=False)
result = result_type.model_validate(response_or_error.result, by_name=False)
if span is not None:
end_span_ok(span)
return result

except BaseException as exc:
if span is not None:
end_span_error(span, exc)
raise

finally:
self._response_streams.pop(request_id, None)
Expand Down Expand Up @@ -339,13 +364,16 @@ async def _receive_loop(self) -> None:
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
by_name=False,
)
request_data = message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
server_span = start_server_span(request_data["method"], request_data.get("params"))
responder = RequestResponder(
request_id=message.message.id,
request_meta=validated_request.params.meta if validated_request.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
span=server_span,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
Expand Down
81 changes: 81 additions & 0 deletions src/mcp/shared/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Any

from opentelemetry import trace
from opentelemetry.trace import StatusCode

_tracer = trace.get_tracer("mcp")

_EXCLUDED_METHODS: frozenset[str] = frozenset({"notifications/message"})

# Semantic convention attribute keys
ATTR_MCP_METHOD_NAME = "mcp.method.name"
ATTR_ERROR_TYPE = "error.type"

# Methods that have a meaningful target name in params
_TARGET_PARAM_KEY: dict[str, str] = {
"tools/call": "name",
"prompts/get": "name",
"resources/read": "uri",
}


def _extract_target(method: str, params: dict[str, Any] | None) -> str | None:
"""Extract the target (e.g. tool name, prompt name) from request params."""
key = _TARGET_PARAM_KEY.get(method)
if key is None or params is None:
return None
value = params.get(key)
if isinstance(value, str):
return value
return None


def start_client_span(method: str, params: dict[str, Any] | None) -> trace.Span | None:
"""Start a CLIENT span for an outgoing MCP request.

Returns None if the method is excluded from tracing.
"""
if method in _EXCLUDED_METHODS:
return None

target = _extract_target(method, params)
span_name = f"{method} {target}" if target else method
span = _tracer.start_span(
span_name,
kind=trace.SpanKind.CLIENT,
attributes={ATTR_MCP_METHOD_NAME: method},
)
return span


def start_server_span(method: str, params: dict[str, Any] | None) -> trace.Span | None:
"""Start a SERVER span for an incoming MCP request.

Returns None if the method is excluded from tracing.
"""
if method in _EXCLUDED_METHODS:
return None

target = _extract_target(method, params)
span_name = f"{method} {target}" if target else method
span = _tracer.start_span(
span_name,
kind=trace.SpanKind.SERVER,
attributes={ATTR_MCP_METHOD_NAME: method},
)
return span


def end_span_ok(span: trace.Span) -> None:
"""Mark a span as successful and end it."""
span.set_status(StatusCode.OK)
span.end()


def end_span_error(span: trace.Span, error: BaseException) -> None:
"""Mark a span as errored and end it."""
span.set_status(StatusCode.ERROR, str(error))
span.set_attribute(ATTR_ERROR_TYPE, type(error).__qualname__)
span.end()
Loading
Loading