diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java new file mode 100644 index 000000000..5ee9b85fd --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java @@ -0,0 +1,58 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.URISyntaxException; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default {@link SseMessageEndpointValidator} that validates the {@code message} endpoint + * advertised by an SSE server. Message endpoints must be a relative URI, without path + * traversal or authority. + * + * @author Daniel Garnier-Moiroux + */ +public final class DefaultSseMessageEndpointValidator implements SseMessageEndpointValidator { + + @Override + public void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessageEndpointException { + Assert.hasText(messageEndpoint, "messageEndpoint must not be empty"); + + URI endpointUri; + try { + endpointUri = new URI(messageEndpoint); + } + catch (URISyntaxException ex) { + throw new InvalidSseMessageEndpointException("messageEndpoint is not a valid URI: " + ex.getMessage(), + messageEndpoint); + } + + if (endpointUri.isAbsolute()) { + // Exclude absolute URIs e.g. https://example.com/mcp + throw new InvalidSseMessageEndpointException("messageEndpoint must be a relative path, not an absolute URI", + messageEndpoint); + } + + if (endpointUri.getRawAuthority() != null) { + // Exclude network paths e.g. //example.com/mcp + throw new InvalidSseMessageEndpointException( + "messageEndpoint must be a relative path and must not contain an authority", messageEndpoint); + } + + // Exclude path-traversal + String decodedPath = endpointUri.getPath(); + if (decodedPath != null) { + for (String segment : decodedPath.split("/", -1)) { + if (".".equals(segment) || "..".equals(segment)) { + throw new InvalidSseMessageEndpointException( + "messageEndpoint must not contain path-traversal segments", messageEndpoint); + } + } + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 70d8b68e3..050c7dd9a 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -16,8 +16,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; @@ -33,6 +31,8 @@ import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -117,6 +117,11 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + /** + * Validator for the message endpoint; + */ + private final SseMessageEndpointValidator messageEndpointValidator; + /** * Creates a new transport instance with custom HTTP client builder, object mapper, * and headers. @@ -127,22 +132,26 @@ public class HttpClientSseClientTransport implements McpClientTransport { * @param jsonMapper the object mapper for JSON serialization/deserialization * @param httpRequestCustomizer customizer for the requestBuilder before executing * requests + * @param messageEndpointValidator validator for the message endpoint * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + SseMessageEndpointValidator messageEndpointValidator) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); + Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.httpRequestCustomizer = httpRequestCustomizer; + this.messageEndpointValidator = messageEndpointValidator; } @Override @@ -178,6 +187,8 @@ public static class Builder { private Duration connectTimeout = Duration.ofSeconds(10); + private SseMessageEndpointValidator messageEndpointValidator = new DefaultSseMessageEndpointValidator(); + /** * Creates a new builder instance. */ @@ -297,6 +308,18 @@ public Builder connectTimeout(Duration connectTimeout) { return this; } + /** + * Sets the validator that ensure the message endpoint returned over the SSE + * connection is valid. + * @param messageEndpointValidator the validator + * @return this builder + */ + public Builder messageEndpointValidator(SseMessageEndpointValidator messageEndpointValidator) { + Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null"); + this.messageEndpointValidator = messageEndpointValidator; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance @@ -304,7 +327,8 @@ public Builder connectTimeout(Duration connectTimeout) { public HttpClientSseClientTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, - jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer); + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer, + messageEndpointValidator); } } @@ -342,6 +366,14 @@ public Mono connect(Function, Mono> h try { if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { String messageEndpointUri = responseEvent.sseEvent().data(); + try { + messageEndpointValidator.validate(uri, messageEndpointUri); + } + catch (InvalidSseMessageEndpointException e) { + sink.error(e); + this.messageEndpointSink.tryEmitError(e); + return Flux.error(e); + } if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { sink.success(); return Flux.empty(); // No further processing needed diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/InvalidSseMessageEndpointException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/InvalidSseMessageEndpointException.java new file mode 100644 index 000000000..6acdfae51 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/InvalidSseMessageEndpointException.java @@ -0,0 +1,26 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +/** + * Exception thrown when the {@code message} endpoint returned from the SSE connection is + * not valid. + * + * @author Daniel Garnier-Moiroux + */ +public class InvalidSseMessageEndpointException extends Exception { + + private final String messageEndpoint; + + public InvalidSseMessageEndpointException(String message, String messageEndpoint) { + super(message); + this.messageEndpoint = messageEndpoint; + } + + public String getMessageEndpoint() { + return messageEndpoint; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/SseMessageEndpointValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/SseMessageEndpointValidator.java new file mode 100644 index 000000000..322e64638 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/SseMessageEndpointValidator.java @@ -0,0 +1,27 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; + +/** + * Validate the that message endpoint in the SSE transport is valid. Throws + * {@link InvalidSseMessageEndpointException} when then endpoint is not valid. + * + * @author Daniel Garnier-Moiroux + */ +@FunctionalInterface +public interface SseMessageEndpointValidator { + + /** + * Validate the message endpoint coming from an SSE connection. Throws if not valid. + * @param sseUri the URI used to establish the SSE connection + * @param messageEndpoint the message endpoint from the SSE connection + * @throws InvalidSseMessageEndpointException error thrown if the message endpoint is + * not valid. + */ + void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessageEndpointException; + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java new file mode 100644 index 000000000..f1fc82850 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import java.net.URI; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullSource; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +/** + * Tests for {@link DefaultSseMessageEndpointValidator}. + * + * @author Daniel Garnier-Moiroux + */ +class DefaultSseMessageEndpointValidatorTests { + + private static final URI SSE_URI = URI.create("https://mcp.example.com/sse"); + + private final DefaultSseMessageEndpointValidator validator = new DefaultSseMessageEndpointValidator(); + + @ParameterizedTest + @ValueSource(strings = { "/messages", "messages?session=abc", "/" }) + void valid(String endpoint) { + assertThatCode(() -> validator.validate(SSE_URI, endpoint)).doesNotThrowAnyException(); + } + + @ParameterizedTest + @ValueSource(strings = { "", " ", "\t" }) + @NullSource + void invalidEmpty(String endpoint) { + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messageEndpoint must not be empty"); + } + + @ParameterizedTest + @ValueSource(strings = { "/foo/../bar", "/foo/./bar", "../bar", "./bar", "/foo/%2E%2E/bar", "/foo/%2e/bar" }) + void invalidPathTraversal(String endpoint) { + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).hasMessageContaining("path-traversal") + .asInstanceOf(type(InvalidSseMessageEndpointException.class)) + .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) + .isEqualTo(endpoint); + } + + @ParameterizedTest + @ValueSource(strings = { "https://mcp.example.com/messages", "https://127.0.0.1/messages", + "https://mcp.example.com:8443/messages", "http://localhost:1234/messages", "file:///etc/passwd", + "gopher://mcp.example.com/_test" }) + void invalidAbsoluteUris(String endpoint) { + // Even an absolute URI on the same origin must be rejected: the contract + // is that the messageEndpoint is a path-only relative reference. + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).hasMessageContaining("must be a relative path") + .asInstanceOf(type(InvalidSseMessageEndpointException.class)) + .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) + .isEqualTo(endpoint); + + } + + @ParameterizedTest + @ValueSource(strings = { "//example/messages", "//user:secret@example/messages" }) + void invalidNetworkReference(String endpoint) { + // `//host/...` introduces an authority and is therefore not a pure path. + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)) + .hasMessageContaining("must not contain an authority") + .asInstanceOf(type(InvalidSseMessageEndpointException.class)) + .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) + .isEqualTo(endpoint); + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index f3bc17f5b..cb17e9fbf 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -11,7 +11,6 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; @@ -19,7 +18,6 @@ import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; - import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -35,13 +33,13 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.util.UriComponentsBuilder; - import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.ArgumentMatchers.matches; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -66,6 +64,8 @@ class HttpClientSseClientTransportTests { private TestHttpClientSseClientTransport transport; + private SseMessageEndpointValidator sseMessageEndpointValidator = mock(SseMessageEndpointValidator.class); + private final McpTransportContext context = McpTransportContext.create(Map.of("some-key", "some-value")); // Test class to access protected methods @@ -75,10 +75,11 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestHttpClientSseClientTransport(final String baseUri) { + public TestHttpClientSseClientTransport(final String baseUri, + SseMessageEndpointValidator sseMessageEndpointValidator) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, - McpAsyncHttpClientRequestCustomizer.NOOP); + McpAsyncHttpClientRequestCustomizer.NOOP, sseMessageEndpointValidator); } public int getInboundMessageCount() { @@ -112,7 +113,7 @@ static void stopContainer() { @BeforeEach void setUp() { - transport = new TestHttpClientSseClientTransport(host); + transport = new TestHttpClientSseClientTransport(host, sseMessageEndpointValidator); transport.connect(Function.identity()).block(); } @@ -417,4 +418,44 @@ void testAsyncRequestCustomizer() { customizedTransport.closeGracefully().block(); } + @Test + void testMessageEndpointValidation() throws InvalidSseMessageEndpointException { + var uriCaptor = ArgumentCaptor.forClass(URI.class); + verify(sseMessageEndpointValidator).validate(uriCaptor.capture(), matches("/message\\?sessionId=[a-z0-9-]+")); + assertThat(uriCaptor.getValue().toString()).matches("http://localhost:\\d+/sse"); + } + + @Test + void testMessageEndpointValidationRejects() { + TestHttpClientSseClientTransport transport = new TestHttpClientSseClientTransport(host, + (sseUri, messageEndpoint) -> { + throw new InvalidSseMessageEndpointException("boom", messageEndpoint); + }); + + try { + // fails to connect + StepVerifier.create(transport.connect(Function.identity())) + .verifyErrorMatches(HttpClientSseClientTransportTests::isInvalidEndpointError); + + // Since connection failed, there is no message endpoint, and no message can + // be sent + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + StepVerifier.create(transport.sendMessage(testMessage)) + .verifyErrorMatches(HttpClientSseClientTransportTests::isInvalidEndpointError); + } + finally { + transport.closeGracefully(); + } + } + + private static boolean isInvalidEndpointError(Throwable e) { + if (e instanceof InvalidSseMessageEndpointException ismee) { + return ismee.getMessageEndpoint().matches("/message\\?sessionId=[a-z0-9-]+") + && ismee.getMessage().equals("boom"); + } + return false; + } + }