Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"pyjwt[crypto]>=2.10.1",
"typing-extensions>=4.13.0",
"typing-inspection>=0.4.1",
"opentelemetry-api>=1.28.0",
"opentelemetry-api>=1.30.0",
]
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bumped from 1.28.0: explicit_bucket_boundaries_advisory was added to Meter.create_histogram() in opentelemetry-api 1.30.0. The previous minimum caused a TypeError at runtime when the OTel proxy meter replayed histogram creation against the real meter.


[project.optional-dependencies]
Expand Down Expand Up @@ -72,7 +72,7 @@ dev = [
"coverage[toml]>=7.10.7,<=7.13",
"pillow>=12.0",
"strict-no-cover",
"logfire>=3.0.0",
"logfire>=3.20.0",
]
Copy link
Copy Markdown
Author

@verdie-g verdie-g Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bumped from 3.0.0: logfire's _ProxyMeter.create_histogram() did not forward unknown kwargs (including explicit_bucket_boundaries_advisory) until 3.20.0, causing a TypeError when logfire was configured as the meter provider. Also replaced capfire.get_collected_metrics() (added in logfire 4.0.0) with capfire.metrics_reader.get_metrics_data() which works from 3.0.0, see _get_mcp_metrics in test_otel.py.

docs = [
"mkdocs>=1.6.1",
Expand Down
49 changes: 45 additions & 4 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def main():

import contextvars
import logging
import time
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
Expand Down Expand Up @@ -66,7 +67,7 @@ async def main():
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._otel import extract_trace_context, otel_span
from mcp.shared._otel import extract_trace_context, otel_span, record_server_operation_duration
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.exceptions import MCPError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
Expand Down Expand Up @@ -455,12 +456,36 @@ async def _handle_request(
meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None
parent_context = extract_trace_context(meta) if meta is not None else None

mcp_protocol_version: str | None = (
str(session.client_params.protocol_version) if session.client_params else None
)

start_time = time.monotonic()

def _record_duration(
*,
error_type: str | None = None,
rpc_response_status_code: str | None = None,
) -> None:
record_server_operation_duration(
time.monotonic() - start_time,
req.method,
error_type=error_type,
rpc_response_status_code=rpc_response_status_code,
tool_name=target if req.method == "tools/call" else None,
prompt_name=target if req.method == "prompts/get" else None,
mcp_protocol_version=mcp_protocol_version,
)

with otel_span(
span_name,
kind=SpanKind.SERVER,
attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id},
context=parent_context,
) as span:
error_type: str | None = None
rpc_response_status_code: str | None = None

if handler := self._request_handlers.get(req.method):
logger.debug("Dispatching request of type %s", type(req).__name__)

Expand Down Expand Up @@ -499,25 +524,38 @@ async def _handle_request(
)
response = await handler(ctx, req.params)
except MCPError as err:
rpc_response_status_code = str(err.error.code)
error_type = rpc_response_status_code
response = err.error
except anyio.get_cancelled_exc_class():
if message.cancelled:
# Client sent CancelledNotification; responder.cancel() already
# sent an error response, so skip the duplicate.
logger.info("Request %s cancelled - duplicate response suppressed", message.request_id)
_record_duration(error_type="cancelled")
return
# Transport-close cancellation from the TG in run(); re-raise so the
# TG swallows its own cancellation.
raise
except Exception as err:
if raise_exceptions: # pragma: no cover
error_type = type(err).__name__
if raise_exceptions:
_record_duration(error_type=error_type)
raise err
response = types.ErrorData(code=0, message=str(err))
else: # pragma: no cover
rpc_response_status_code = str(types.METHOD_NOT_FOUND)
error_type = rpc_response_status_code
response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found")

if isinstance(response, types.ErrorData) and span is not None:
span.set_status(StatusCode.ERROR, response.message)
if isinstance(response, types.ErrorData):
if span is not None:
span.set_status(StatusCode.ERROR, response.message)
# Only set error_type/rpc_response_status_code from response code if not
# already set by an exception.
if error_type is None: # pragma: no cover
rpc_response_status_code = str(response.code) # pragma: no cover
error_type = rpc_response_status_code # pragma: no cover

try:
await message.respond(response)
Expand All @@ -529,10 +567,13 @@ async def _handle_request(
# end closed (_receive_loop's async-with exit); Broken if the peer
# end closed first (streamable_http terminate()).
logger.debug("Response for %s dropped - transport closed", message.request_id)
_record_duration(error_type=error_type, rpc_response_status_code=rpc_response_status_code)
return

logger.debug("Response sent")

_record_duration(error_type=error_type, rpc_response_status_code=rpc_response_status_code)

async def _handle_notification(
self,
notify: types.ClientNotification,
Expand Down
25 changes: 25 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
be instantiated directly by users of the MCP framework.
"""

import time
from enum import Enum
from typing import Any, TypeVar, overload

Expand All @@ -40,6 +41,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
from mcp.server.models import InitializationOptions
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared._otel import record_server_session_duration
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.exceptions import StatelessModeNotSupported
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
Expand Down Expand Up @@ -96,6 +98,29 @@ def __init__(
ServerRequestResponder
](0)
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
self._session_start_time: float | None = None

async def __aenter__(self) -> "ServerSession":
self._session_start_time = time.monotonic()
return await super().__aenter__()

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any
) -> bool | None:
if self._session_start_time is not None: # pragma: no branch
duration = time.monotonic() - self._session_start_time
mcp_protocol_version: str | None = (
str(self._client_params.protocol_version) if self._client_params else None
)
# Cancellation exceptions indicate transport close, not a session error.
is_cancellation = exc_val is not None and isinstance(exc_val, anyio.get_cancelled_exc_class())
error_type: str | None = type(exc_val).__name__ if exc_val is not None and not is_cancellation else None
record_server_session_duration(
duration,
error_type=error_type,
mcp_protocol_version=mcp_protocol_version,
)
return await super().__aexit__(exc_type, exc_val, exc_tb)

@property
def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
Expand Down
63 changes: 63 additions & 0 deletions src/mcp/shared/_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,32 @@
from typing import Any

from opentelemetry.context import Context
from opentelemetry.metrics import get_meter
from opentelemetry.propagate import extract, inject
from opentelemetry.trace import SpanKind, get_tracer

_tracer = get_tracer("mcp-python-sdk")
_meter = get_meter("mcp-python-sdk")

# Metrics as defined by the OTEL semconv https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/mcp.md
_DURATION_BUCKETS = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60, 120, 300]

_server_operation_duration = _meter.create_histogram(
"mcp.server.operation.duration",
unit="s",
description=(
"MCP request or notification duration as observed on the receiver "
"from the time it was received until the result or ack is sent."
),
explicit_bucket_boundaries_advisory=_DURATION_BUCKETS,
)

_server_session_duration = _meter.create_histogram(
"mcp.server.session.duration",
unit="s",
description="The duration of the MCP session as observed on the MCP server.",
explicit_bucket_boundaries_advisory=_DURATION_BUCKETS,
)


@contextmanager
Expand All @@ -34,3 +56,44 @@ def inject_trace_context(meta: dict[str, Any]) -> None:
def extract_trace_context(meta: dict[str, Any]) -> Context:
"""Extract W3C trace context from a `_meta` dict."""
return extract(meta)


def record_server_operation_duration(
duration_s: float,
method: str,
*,
error_type: str | None = None,
rpc_response_status_code: str | None = None,
tool_name: str | None = None,
prompt_name: str | None = None,
mcp_protocol_version: str | None = None,
) -> None:
"""Record a data point for mcp.server.operation.duration."""
attributes: dict[str, str] = {"mcp.method.name": method}
if error_type is not None:
attributes["error.type"] = error_type
if rpc_response_status_code is not None:
attributes["rpc.response.status_code"] = rpc_response_status_code
if tool_name is not None:
attributes["gen_ai.tool.name"] = tool_name
attributes["gen_ai.operation.name"] = "execute_tool"
if prompt_name is not None:
attributes["gen_ai.prompt.name"] = prompt_name
if mcp_protocol_version is not None:
attributes["mcp.protocol.version"] = mcp_protocol_version
_server_operation_duration.record(duration_s, attributes)


def record_server_session_duration(
duration_s: float,
*,
error_type: str | None = None,
mcp_protocol_version: str | None = None,
) -> None:
"""Record a data point for mcp.server.session.duration."""
attributes: dict[str, str] = {}
if error_type is not None:
attributes["error.type"] = error_type
if mcp_protocol_version is not None:
attributes["mcp.protocol.version"] = mcp_protocol_version
_server_session_duration.record(duration_s, attributes)
94 changes: 92 additions & 2 deletions tests/shared/test_otel.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,35 @@
from __future__ import annotations

import json
from typing import Any, cast

import pytest
from logfire.testing import CaptureLogfire
from opentelemetry.sdk.metrics._internal.point import MetricsData

from mcp import types
from mcp.client.client import Client
from mcp.server.context import ServerRequestContext
from mcp.server.lowlevel.server import Server
from mcp.server.mcpserver import MCPServer
from mcp.shared.exceptions import MCPError

pytestmark = pytest.mark.anyio


def _get_mcp_metrics(capfire: CaptureLogfire) -> dict[str, Any]:
"""Return collected metrics whose name starts with 'mcp.', keyed by name."""
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CaptureLogfire.get_collected_metrics() was only added in logfire 4.0.0, but our minimum is 3.20.0. This helper replicates the same logic using capfire.metrics_reader.get_metrics_data(), which is available from logfire 3.0.0 (the metrics_reader field is part of the public CaptureLogfire dataclass).

exported = json.loads(cast(MetricsData, capfire.metrics_reader.get_metrics_data()).to_json())
[resource_metric] = exported["resource_metrics"]
all_metrics = [metric for scope_metric in resource_metric["scope_metrics"] for metric in scope_metric["metrics"]]
return {m["name"]: m for m in all_metrics if m["name"].startswith("mcp.")}


# Logfire warns about propagated trace context by default (distributed_tracing=None).
# This is expected here since we're testing cross-boundary context propagation.
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_client_and_server_spans(capfire: CaptureLogfire):
"""Verify that calling a tool produces client and server spans with correct attributes."""
async def test_client_and_server_instrumentation(capfire: CaptureLogfire):
"""Verify that calling a tool produces client and server spans and metrics with correct attributes."""
server = MCPServer("test")

@server.tool()
Expand Down Expand Up @@ -42,3 +57,78 @@ def greet(name: str) -> str:

# Server span should be in the same trace as the client span (context propagation).
assert server_span["context"]["trace_id"] == client_span["context"]["trace_id"]

metrics = _get_mcp_metrics(capfire)

assert "mcp.server.operation.duration" in metrics
assert "mcp.server.session.duration" in metrics

op_metric = metrics["mcp.server.operation.duration"]
assert op_metric["unit"] == "s"
op_points = op_metric["data"]["data_points"]

# tools/call data point
tools_call_point = next(p for p in op_points if p["attributes"]["mcp.method.name"] == "tools/call")
assert tools_call_point["attributes"]["gen_ai.tool.name"] == "greet"
assert tools_call_point["attributes"]["gen_ai.operation.name"] == "execute_tool"
assert tools_call_point["attributes"]["mcp.protocol.version"] == "2025-11-25"
assert tools_call_point["count"] == 1
assert tools_call_point["sum"] >= 0

# tools/list is also called during initialization
assert any(p["attributes"]["mcp.method.name"] == "tools/list" for p in op_points)

session_metric = metrics["mcp.server.session.duration"]
assert session_metric["unit"] == "s"
[session_point] = session_metric["data"]["data_points"]
assert session_point["attributes"]["mcp.protocol.version"] == "2025-11-25"
assert "error.type" not in session_point["attributes"]
assert session_point["count"] == 1
assert session_point["sum"] >= 0


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_server_operation_error_metrics(capfire: CaptureLogfire):
"""Verify that error.type and rpc.response.status_code are set when a handler raises MCPError."""

async def handle_call_tool(
ctx: ServerRequestContext[Any], params: types.CallToolRequestParams
) -> types.CallToolResult:
raise MCPError(types.INVALID_PARAMS, "bad params")

server = Server("test", on_call_tool=handle_call_tool)

async with Client(server) as client:
with pytest.raises(MCPError):
await client.call_tool("boom", {})

metrics = _get_mcp_metrics(capfire)
op_points = metrics["mcp.server.operation.duration"]["data"]["data_points"]
error_point = next(p for p in op_points if p["attributes"]["mcp.method.name"] == "tools/call")
assert error_point["attributes"]["error.type"] == str(types.INVALID_PARAMS)
assert error_point["attributes"]["rpc.response.status_code"] == str(types.INVALID_PARAMS)


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_server_session_error_metrics(capfire: CaptureLogfire):
"""Verify that error.type is set on session duration when the session exits with an exception."""

async def handle_call_tool(
ctx: ServerRequestContext[Any], params: types.CallToolRequestParams
) -> types.CallToolResult:
raise RuntimeError("unexpected crash")

server = Server("test", on_call_tool=handle_call_tool)

# raise_exceptions=True lets the RuntimeError escape the handler and crash the session,
# simulating what happens in production when an unhandled exception exits the session block.
with pytest.raises(Exception):
async with Client(server, raise_exceptions=True) as client:
await client.call_tool("boom", {})

metrics = _get_mcp_metrics(capfire)
session_points = metrics["mcp.server.session.duration"]["data"]["data_points"]
error_session_points = [p for p in session_points if "error.type" in p["attributes"]]
assert len(error_session_points) >= 1
# anyio wraps task group exceptions in ExceptionGroup
assert error_session_points[0]["attributes"]["error.type"] in ("RuntimeError", "ExceptionGroup")
9 changes: 4 additions & 5 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ async def _handle_context_call_tool( # pragma: no cover
if name == "echo_headers":
headers_info: dict[str, Any] = {}
if ctx.request and isinstance(ctx.request, Request):
headers_info = dict(ctx.request.headers)
headers_info = dict(ctx.request.headers) # pyright: ignore[reportUnknownMemberType]
return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))])

elif name == "echo_context":
Expand All @@ -1508,10 +1508,9 @@ async def _handle_context_call_tool( # pragma: no cover
"path": None,
}
if ctx.request and isinstance(ctx.request, Request):
request = ctx.request
context_data["headers"] = dict(request.headers)
context_data["method"] = request.method
context_data["path"] = request.url.path
context_data["headers"] = dict(ctx.request.headers) # pyright: ignore[reportUnknownMemberType]
context_data["method"] = ctx.request.method # pyright: ignore[reportUnknownMemberType]
context_data["path"] = ctx.request.url.path # pyright: ignore[reportUnknownMemberType]
return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))])

return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")])
Expand Down
Loading
Loading