diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java new file mode 100644 index 000000000..5ee9b85fd --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java @@ -0,0 +1,58 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.URISyntaxException; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default {@link SseMessageEndpointValidator} that validates the {@code message} endpoint + * advertised by an SSE server. Message endpoints must be a relative URI, without path + * traversal or authority. + * + * @author Daniel Garnier-Moiroux + */ +public final class DefaultSseMessageEndpointValidator implements SseMessageEndpointValidator { + + @Override + public void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessageEndpointException { + Assert.hasText(messageEndpoint, "messageEndpoint must not be empty"); + + URI endpointUri; + try { + endpointUri = new URI(messageEndpoint); + } + catch (URISyntaxException ex) { + throw new InvalidSseMessageEndpointException("messageEndpoint is not a valid URI: " + ex.getMessage(), + messageEndpoint); + } + + if (endpointUri.isAbsolute()) { + // Exclude absolute URIs e.g. https://example.com/mcp + throw new InvalidSseMessageEndpointException("messageEndpoint must be a relative path, not an absolute URI", + messageEndpoint); + } + + if (endpointUri.getRawAuthority() != null) { + // Exclude network paths e.g. //example.com/mcp + throw new InvalidSseMessageEndpointException( + "messageEndpoint must be a relative path and must not contain an authority", messageEndpoint); + } + + // Exclude path-traversal + String decodedPath = endpointUri.getPath(); + if (decodedPath != null) { + for (String segment : decodedPath.split("/", -1)) { + if (".".equals(segment) || "..".equals(segment)) { + throw new InvalidSseMessageEndpointException( + "messageEndpoint must not contain path-traversal segments", messageEndpoint); + } + } + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 66e0b9d44..5a4ad3f99 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -16,8 +16,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; @@ -33,6 +31,8 @@ import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -117,6 +117,11 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; + /** + * Validator for the message endpoint; + */ + private final SseMessageEndpointValidator messageEndpointValidator; + /** * Creates a new transport instance with custom HTTP client builder, object mapper, * and headers. @@ -127,22 +132,26 @@ public class HttpClientSseClientTransport implements McpClientTransport { * @param jsonMapper the object mapper for JSON serialization/deserialization * @param httpRequestCustomizer customizer for the requestBuilder before executing * requests + * @param messageEndpointValidator validator for the message endpoint * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + SseMessageEndpointValidator messageEndpointValidator) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); + Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.httpRequestCustomizer = httpRequestCustomizer; + this.messageEndpointValidator = messageEndpointValidator; } @Override @@ -178,6 +187,8 @@ public static class Builder { private Duration connectTimeout = Duration.ofSeconds(10); + private SseMessageEndpointValidator messageEndpointValidator = new DefaultSseMessageEndpointValidator(); + /** * Creates a new builder instance. */ @@ -321,6 +332,18 @@ public Builder connectTimeout(Duration connectTimeout) { return this; } + /** + * Sets the validator that ensure the message endpoint returned over the SSE + * connection is valid. + * @param messageEndpointValidator the validator + * @return this builder + */ + public Builder messageEndpointValidator(SseMessageEndpointValidator messageEndpointValidator) { + Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null"); + this.messageEndpointValidator = messageEndpointValidator; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance @@ -328,7 +351,8 @@ public Builder connectTimeout(Duration connectTimeout) { public HttpClientSseClientTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, - jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer); + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer, + messageEndpointValidator); } } @@ -366,6 +390,14 @@ public Mono connect(Function, Mono> h try { if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { String messageEndpointUri = responseEvent.sseEvent().data(); + try { + messageEndpointValidator.validate(uri, messageEndpointUri); + } + catch (InvalidSseMessageEndpointException e) { + sink.error(e); + this.messageEndpointSink.tryEmitError(e); + return Flux.error(e); + } if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { sink.success(); return Flux.empty(); // No further processing needed diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/InvalidSseMessageEndpointException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/InvalidSseMessageEndpointException.java new file mode 100644 index 000000000..6acdfae51 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/InvalidSseMessageEndpointException.java @@ -0,0 +1,26 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +/** + * Exception thrown when the {@code message} endpoint returned from the SSE connection is + * not valid. + * + * @author Daniel Garnier-Moiroux + */ +public class InvalidSseMessageEndpointException extends Exception { + + private final String messageEndpoint; + + public InvalidSseMessageEndpointException(String message, String messageEndpoint) { + super(message); + this.messageEndpoint = messageEndpoint; + } + + public String getMessageEndpoint() { + return messageEndpoint; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/SseMessageEndpointValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/SseMessageEndpointValidator.java new file mode 100644 index 000000000..322e64638 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/SseMessageEndpointValidator.java @@ -0,0 +1,27 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; + +/** + * Validate the that message endpoint in the SSE transport is valid. Throws + * {@link InvalidSseMessageEndpointException} when then endpoint is not valid. + * + * @author Daniel Garnier-Moiroux + */ +@FunctionalInterface +public interface SseMessageEndpointValidator { + + /** + * Validate the message endpoint coming from an SSE connection. Throws if not valid. + * @param sseUri the URI used to establish the SSE connection + * @param messageEndpoint the message endpoint from the SSE connection + * @throws InvalidSseMessageEndpointException error thrown if the message endpoint is + * not valid. + */ + void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessageEndpointException; + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java new file mode 100644 index 000000000..f1fc82850 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport; + +import java.net.URI; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullSource; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +/** + * Tests for {@link DefaultSseMessageEndpointValidator}. + * + * @author Daniel Garnier-Moiroux + */ +class DefaultSseMessageEndpointValidatorTests { + + private static final URI SSE_URI = URI.create("https://mcp.example.com/sse"); + + private final DefaultSseMessageEndpointValidator validator = new DefaultSseMessageEndpointValidator(); + + @ParameterizedTest + @ValueSource(strings = { "/messages", "messages?session=abc", "/" }) + void valid(String endpoint) { + assertThatCode(() -> validator.validate(SSE_URI, endpoint)).doesNotThrowAnyException(); + } + + @ParameterizedTest + @ValueSource(strings = { "", " ", "\t" }) + @NullSource + void invalidEmpty(String endpoint) { + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messageEndpoint must not be empty"); + } + + @ParameterizedTest + @ValueSource(strings = { "/foo/../bar", "/foo/./bar", "../bar", "./bar", "/foo/%2E%2E/bar", "/foo/%2e/bar" }) + void invalidPathTraversal(String endpoint) { + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).hasMessageContaining("path-traversal") + .asInstanceOf(type(InvalidSseMessageEndpointException.class)) + .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) + .isEqualTo(endpoint); + } + + @ParameterizedTest + @ValueSource(strings = { "https://mcp.example.com/messages", "https://127.0.0.1/messages", + "https://mcp.example.com:8443/messages", "http://localhost:1234/messages", "file:///etc/passwd", + "gopher://mcp.example.com/_test" }) + void invalidAbsoluteUris(String endpoint) { + // Even an absolute URI on the same origin must be rejected: the contract + // is that the messageEndpoint is a path-only relative reference. + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).hasMessageContaining("must be a relative path") + .asInstanceOf(type(InvalidSseMessageEndpointException.class)) + .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) + .isEqualTo(endpoint); + + } + + @ParameterizedTest + @ValueSource(strings = { "//example/messages", "//user:secret@example/messages" }) + void invalidNetworkReference(String endpoint) { + // `//host/...` introduces an authority and is therefore not a pure path. + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)) + .hasMessageContaining("must not contain an authority") + .asInstanceOf(type(InvalidSseMessageEndpointException.class)) + .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) + .isEqualTo(endpoint); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 304a3435f..350a204ba 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -5,7 +5,9 @@ package io.modelcontextprotocol.client.transport; import java.io.IOException; +import java.net.URI; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Function; @@ -128,7 +130,17 @@ public class WebFluxSseClientTransport implements McpClientTransport { * The SSE endpoint URI provided by the server. Used for sending outbound messages via * HTTP POST requests. */ - private String sseEndpoint; + private final String sseEndpoint; + + /** + * Used to capture the full SSE URI from the web client when connecting. + */ + private final AtomicReference sseUri = new AtomicReference<>(); + + /** + * Validator for the message endpoint. + */ + private final SseMessageEndpointValidator messageEndpointValidator; /** * Constructs a new SseClientTransport with the specified WebClient builder and @@ -152,13 +164,30 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapp * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint) { + this(webClientBuilder, jsonMapper, sseEndpoint, new DefaultSseMessageEndpointValidator()); + } + + /** + * Constructs a new SseClientTransport with the specified WebClient builder and + * ObjectMapper. Initializes both inbound and outbound message processing pipelines. + * @param webClientBuilder the WebClient.Builder to use for creating the WebClient + * instance + * @param jsonMapper the ObjectMapper to use for JSON processing + * @param sseEndpoint the SSE endpoint URI to use for establishing the connection + * @param messageEndpointValidator validator for the message endpoint + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint, + SseMessageEndpointValidator messageEndpointValidator) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null"); Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); this.jsonMapper = jsonMapper; this.webClient = webClientBuilder.build(); this.sseEndpoint = sseEndpoint; + this.messageEndpointValidator = messageEndpointValidator; } @Override @@ -195,6 +224,14 @@ public Mono connect(Function, Mono> h this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { String messageEndpointUri = event.data(); + try { + this.messageEndpointValidator.validate(this.sseUri.get(), messageEndpointUri); + } + catch (InvalidSseMessageEndpointException ex) { + messageEndpointSink.tryEmitError(ex); + s.error(ex); + return; + } if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { s.complete(); } @@ -276,16 +313,17 @@ public Mono sendMessage(JSONRPCMessage message) { * Includes automatic retry logic for handling transient connection failures. */ // visible for tests - protected Flux> eventStream() {// @formatter:off - return this.webClient - .get() + protected Flux> eventStream() { + return this.webClient.get() .uri(this.sseEndpoint) .accept(MediaType.TEXT_EVENT_STREAM) .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) - .retrieve() - .bodyToFlux(SSE_TYPE) + .exchangeToFlux(exchange -> { + this.sseUri.set(exchange.request().getURI()); + return exchange.bodyToFlux(SSE_TYPE); + }) .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); - } // @formatter:on + } /** * Retry handler for the inbound SSE stream. Implements the retry logic for handling @@ -368,6 +406,8 @@ public static class Builder { private McpJsonMapper jsonMapper; + private SseMessageEndpointValidator messageEndpointValidator = new DefaultSseMessageEndpointValidator(); + /** * Creates a new builder with the specified WebClient.Builder. * @param webClientBuilder the WebClient.Builder to use @@ -399,13 +439,26 @@ public Builder jsonMapper(McpJsonMapper jsonMapper) { return this; } + /** + * Sets the validator that ensure the message endpoint returned over the SSE + * connection is valid. + * @param messageEndpointValidator the validator + * @return this builder + */ + public Builder messageEndpointValidator(SseMessageEndpointValidator messageEndpointValidator) { + Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null"); + this.messageEndpointValidator = messageEndpointValidator; + return this; + } + /** * Builds a new {@link WebFluxSseClientTransport} instance. * @return a new transport instance */ public WebFluxSseClientTransport build() { return new WebFluxSseClientTransport(webClientBuilder, - jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, sseEndpoint); + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, sseEndpoint, + messageEndpointValidator); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 4b0d4e556..d109a32b4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client.transport; +import java.net.URI; import java.time.Duration; import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; @@ -21,6 +22,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Flux; @@ -35,6 +37,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.matches; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for the {@link WebFluxSseClientTransport} class. @@ -57,6 +62,8 @@ class WebFluxSseClientTransportTests { private WebClient.Builder webClientBuilder; + private SseMessageEndpointValidator sseMessageEndpointValidator = mock(SseMessageEndpointValidator.class); + // Test class to access protected methods static class TestSseClientTransport extends WebFluxSseClientTransport { @@ -64,8 +71,9 @@ static class TestSseClientTransport extends WebFluxSseClientTransport { private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) { - super(webClientBuilder, jsonMapper); + public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, + SseMessageEndpointValidator sseMessageEndpointValidator) { + super(webClientBuilder, jsonMapper, "/sse", sseMessageEndpointValidator); } @Override @@ -113,7 +121,7 @@ static void cleanup() { @BeforeEach void setUp() { webClientBuilder = WebClient.builder().baseUrl(host); - transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER); + transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER, sseMessageEndpointValidator); transport.connect(Function.identity()).block(); } @@ -368,4 +376,44 @@ void testMessageOrderPreservation() { assertThat(transport.getInboundMessageCount()).isEqualTo(3); } + @Test + void testMessageEndpointValidation() throws InvalidSseMessageEndpointException { + var uriCaptor = ArgumentCaptor.forClass(URI.class); + verify(sseMessageEndpointValidator).validate(uriCaptor.capture(), matches("/message\\?sessionId=[a-z0-9-]+")); + assertThat(uriCaptor.getValue().toString()).matches("http://localhost:\\d+/sse"); + } + + @Test + void testMessageEndpointValidationRejects() { + TestSseClientTransport transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER, + (sseUri, messageEndpoint) -> { + throw new InvalidSseMessageEndpointException("boom", messageEndpoint); + }); + + try { + // fails to connect + StepVerifier.create(transport.connect(Function.identity())) + .verifyErrorMatches(WebFluxSseClientTransportTests::isInvalidEndpointError); + + // Since connection failed, there is no message endpoint, and no message can + // be sent + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + StepVerifier.create(transport.sendMessage(testMessage)) + .verifyErrorMatches(WebFluxSseClientTransportTests::isInvalidEndpointError); + } + finally { + transport.closeGracefully(); + } + } + + private static boolean isInvalidEndpointError(Throwable e) { + if (e instanceof InvalidSseMessageEndpointException ismee) { + return ismee.getMessageEndpoint().matches("/message\\?sessionId=[a-z0-9-]+") + && ismee.getMessage().equals("boom"); + } + return false; + } + } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index a24805a30..9b065f0c9 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -19,7 +19,6 @@ import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; - import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -35,13 +34,13 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.util.UriComponentsBuilder; - import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.ArgumentMatchers.matches; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -66,6 +65,8 @@ class HttpClientSseClientTransportTests { private TestHttpClientSseClientTransport transport; + private SseMessageEndpointValidator sseMessageEndpointValidator = mock(SseMessageEndpointValidator.class); + private final McpTransportContext context = McpTransportContext.create(Map.of("some-key", "some-value")); // Test class to access protected methods @@ -75,10 +76,11 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestHttpClientSseClientTransport(final String baseUri) { + public TestHttpClientSseClientTransport(final String baseUri, + SseMessageEndpointValidator sseMessageEndpointValidator) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, - McpAsyncHttpClientRequestCustomizer.NOOP); + McpAsyncHttpClientRequestCustomizer.NOOP, sseMessageEndpointValidator); } public int getInboundMessageCount() { @@ -112,7 +114,7 @@ static void stopContainer() { @BeforeEach void setUp() { - transport = new TestHttpClientSseClientTransport(host); + transport = new TestHttpClientSseClientTransport(host, sseMessageEndpointValidator); transport.connect(Function.identity()).block(); } @@ -477,4 +479,44 @@ void testAsyncRequestCustomizer() { customizedTransport.closeGracefully().block(); } + @Test + void testMessageEndpointValidation() throws InvalidSseMessageEndpointException { + var uriCaptor = ArgumentCaptor.forClass(URI.class); + verify(sseMessageEndpointValidator).validate(uriCaptor.capture(), matches("/message\\?sessionId=[a-z0-9-]+")); + assertThat(uriCaptor.getValue().toString()).matches("http://localhost:\\d+/sse"); + } + + @Test + void testMessageEndpointValidationRejects() { + TestHttpClientSseClientTransport transport = new TestHttpClientSseClientTransport(host, + (sseUri, messageEndpoint) -> { + throw new InvalidSseMessageEndpointException("boom", messageEndpoint); + }); + + try { + // fails to connect + StepVerifier.create(transport.connect(Function.identity())) + .verifyErrorMatches(HttpClientSseClientTransportTests::isInvalidEndpointError); + + // Since connection failed, there is no message endpoint, and no message can + // be sent + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + StepVerifier.create(transport.sendMessage(testMessage)) + .verifyErrorMatches(HttpClientSseClientTransportTests::isInvalidEndpointError); + } + finally { + transport.closeGracefully(); + } + } + + private static boolean isInvalidEndpointError(Throwable e) { + if (e instanceof InvalidSseMessageEndpointException ismee) { + return ismee.getMessageEndpoint().matches("/message\\?sessionId=[a-z0-9-]+") + && ismee.getMessage().equals("boom"); + } + return false; + } + }