diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java index 3babea8def..606e20d1a6 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java @@ -26,6 +26,7 @@ */ package org.apache.hc.client5.http.impl; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.apache.hc.core5.annotation.Internal; @@ -70,6 +71,10 @@ public InternalProtocolException(final ProtocolException cause) { } public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException { + if (!containsConnectionUpgrade(response)) { + throw new ProtocolException("Invalid protocol switch response: missing Connection: Upgrade"); + } + final AtomicReference tlsUpgrade = new AtomicReference<>(); parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> { @@ -91,6 +96,16 @@ public ProtocolVersion switchProtocol(final HttpMessage response) throws Protoco } } + private boolean containsConnectionUpgrade(final HttpMessage message) { + final AtomicBoolean found = new AtomicBoolean(false); + MessageSupport.parseTokens(message, HttpHeaders.CONNECTION, token -> { + if ("upgrade".equalsIgnoreCase(token)) { + found.set(true); + } + }); + return found.get(); + } + private ProtocolVersion parseProtocolVersion(final CharSequence buffer, final ParserCursor cursor) throws ProtocolException { TOKENIZER.skipWhiteSpace(buffer, cursor); final String proto = TOKENIZER.parseToken(buffer, cursor, LAX_PROTO_DELIMITER); diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java index 9758607353..43878525b4 100644 --- a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java @@ -52,10 +52,12 @@ void setUp() { @Test void testSwitchToTLS() throws Exception { final HttpResponse response1 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response1.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response1.addHeader(HttpHeaders.UPGRADE, "TLS"); Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response1)); final HttpResponse response2 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response2.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response2.addHeader(HttpHeaders.UPGRADE, "TLS/1.3"); Assertions.assertEquals(TLS.V_1_3.getVersion(), switchStrategy.switchProtocol(response2)); } @@ -63,19 +65,23 @@ void testSwitchToTLS() throws Exception { @Test void testSwitchToHTTP11AndTLS() throws Exception { final HttpResponse response1 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response1.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response1.addHeader(HttpHeaders.UPGRADE, "TLS, HTTP/1.1"); Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response1)); final HttpResponse response2 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response2.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response2.addHeader(HttpHeaders.UPGRADE, ",, HTTP/1.1, TLS, "); Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response2)); final HttpResponse response3 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response3.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response3.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1"); response3.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response3)); final HttpResponse response4 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response4.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response4.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1"); response4.addHeader(HttpHeaders.UPGRADE, "TLS/1.2, TLS/1.3"); Assertions.assertEquals(TLS.V_1_3.getVersion(), switchStrategy.switchProtocol(response4)); @@ -84,14 +90,17 @@ void testSwitchToHTTP11AndTLS() throws Exception { @Test void testSwitchInvalid() { final HttpResponse response1 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response1.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response1.addHeader(HttpHeaders.UPGRADE, "Crap"); Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response1)); final HttpResponse response2 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response2.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response2.addHeader(HttpHeaders.UPGRADE, "TLS, huh?"); Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response2)); final HttpResponse response3 = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response3.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response3.addHeader(HttpHeaders.UPGRADE, ",,,"); Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3)); } @@ -99,6 +108,7 @@ void testSwitchInvalid() { @Test void testWhitespaceOnlyToken() throws ProtocolException { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, " , TLS"); Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response)); } @@ -106,6 +116,7 @@ void testWhitespaceOnlyToken() throws ProtocolException { @Test void testUnsupportedTlsVersion() throws Exception { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4"); Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response)); } @@ -113,6 +124,7 @@ void testUnsupportedTlsVersion() throws Exception { @Test void testUnsupportedTlsMajorVersion() throws Exception { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0"); Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response)); } @@ -120,6 +132,7 @@ void testUnsupportedTlsMajorVersion() throws Exception { @Test void testUnsupportedHttpVersion() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -129,6 +142,7 @@ void testUnsupportedHttpVersion() { @Test void testInvalidTlsFormat() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/abc"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -138,6 +152,7 @@ void testInvalidTlsFormat() { @Test void testHttp11Only() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -147,6 +162,7 @@ void testHttp11Only() { @Test void testSwitchToTlsValid_TLS_1_2() throws Exception { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); final ProtocolVersion result = switchStrategy.switchProtocol(response); Assertions.assertEquals(TLS.V_1_2.getVersion(), result); @@ -155,6 +171,7 @@ void testSwitchToTlsValid_TLS_1_2() throws Exception { @Test void testSwitchToTlsValid_TLS_1_0() throws Exception { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0"); final ProtocolVersion result = switchStrategy.switchProtocol(response); Assertions.assertEquals(TLS.V_1_0.getVersion(), result); @@ -163,6 +180,7 @@ void testSwitchToTlsValid_TLS_1_0() throws Exception { @Test void testSwitchToTlsValid_TLS_1_1() throws Exception { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1"); final ProtocolVersion result = switchStrategy.switchProtocol(response); Assertions.assertEquals(TLS.V_1_1.getVersion(), result); @@ -171,6 +189,7 @@ void testSwitchToTlsValid_TLS_1_1() throws Exception { @Test void testInvalidTlsFormat_NoSlash() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLSv1"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -180,6 +199,7 @@ void testInvalidTlsFormat_NoSlash() { @Test void testSwitchToTlsValid_TLS_1() throws Exception { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/1"); final ProtocolVersion result = switchStrategy.switchProtocol(response); Assertions.assertEquals(TLS.V_1_0.getVersion(), result); @@ -188,6 +208,7 @@ void testSwitchToTlsValid_TLS_1() throws Exception { @Test void testInvalidTlsFormat_MissingMajor() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "TLS/.1"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -197,6 +218,7 @@ void testInvalidTlsFormat_MissingMajor() { @Test void testMultipleHttp11Tokens() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -206,6 +228,7 @@ void testMultipleHttp11Tokens() { @Test void testMixedInvalidAndValidTokens() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); @@ -215,10 +238,41 @@ void testMixedInvalidAndValidTokens() { @Test void testInvalidTlsFormat_NoProtocolName() { final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "Upgrade"); response.addHeader(HttpHeaders.UPGRADE, ",,/1.1"); final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response)); Assertions.assertEquals("Invalid protocol; error at offset 2: <,,/1.1>", ex.getMessage()); } + @Test + void testMissingConnectionUpgradeRejected() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); + + final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> + switchStrategy.switchProtocol(response)); + Assertions.assertEquals("Invalid protocol switch response: missing Connection: Upgrade", ex.getMessage()); + } + + @Test + void testConnectionWithoutUpgradeTokenRejected() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "keep-alive"); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); + + final ProtocolException ex = Assertions.assertThrows(ProtocolException.class, () -> + switchStrategy.switchProtocol(response)); + Assertions.assertEquals("Invalid protocol switch response: missing Connection: Upgrade", ex.getMessage()); + } + + @Test + void testConnectionWithUpgradeTokenInListAccepted() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.CONNECTION, "keep-alive, Upgrade"); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); + + Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response)); + } + } \ No newline at end of file