diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f86f69e38..9d5e891243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ * [FEATURE] Querier: Implement Resource Based Throttling in Querier. #7442 * [ENHANCEMENT] Upgrade prometheus alertmanager version to v0.32.1. #7462 * [ENHANCEMENT] Tenant Federation: Avoid purging the regex resolver LRU cache on user-sync ticks when the set of known users has not changed. #7489 +* [ENHANCEMENT] Memberlist: Add `-memberlist.packet-read-timeout`, `-memberlist.max-packet-size`, and `-memberlist.max-concurrent-connections` flags to bound inbound gossip TCP connections, preventing slow-read, OOM, and connection-flood attacks on the gossip port. #7518 * [ENHANCEMENT] Parquet Converter: Add a ring status page to expose the ring status. #7455 * [ENHANCEMENT] Ingester: Add WAL record metrics to help evaluate the effectiveness of WAL compression type (e.g. snappy, zstd): `cortex_ingester_tsdb_wal_record_part_writes_total`, `cortex_ingester_tsdb_wal_record_parts_bytes_written_total`, and `cortex_ingester_tsdb_wal_record_bytes_saved_total`. #7420 * [ENHANCEMENT] Distributor: Introduce dynamic `Symbols` slice capacity pooling. #7398 #7401 diff --git a/docs/configuration/config-file-reference.md b/docs/configuration/config-file-reference.md index a605dd798b..b71490c2d9 100644 --- a/docs/configuration/config-file-reference.md +++ b/docs/configuration/config-file-reference.md @@ -4693,6 +4693,18 @@ The `memberlist_config` configures the Gossip memberlist. # CLI flag: -memberlist.packet-write-timeout [packet_write_timeout: | default = 5s] +# Timeout for reading packet data from inbound connections. 0 = no limit. +# CLI flag: -memberlist.packet-read-timeout +[packet_read_timeout: | default = 5s] + +# Maximum size in bytes of an inbound gossip packet. 0 = no limit. +# CLI flag: -memberlist.max-packet-size +[max_packet_size: | default = 1048576] + +# Maximum number of concurrent inbound TCP connections. 0 = no limit. +# CLI flag: -memberlist.max-concurrent-connections +[max_concurrent_connections: | default = 100] + # Enable TLS on the memberlist transport layer. # CLI flag: -memberlist.tls-enabled [tls_enabled: | default = false] diff --git a/pkg/ring/kv/memberlist/tcp_transport.go b/pkg/ring/kv/memberlist/tcp_transport.go index cc461cd63b..a5e7e61b15 100644 --- a/pkg/ring/kv/memberlist/tcp_transport.go +++ b/pkg/ring/kv/memberlist/tcp_transport.go @@ -20,6 +20,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "go.uber.org/atomic" + "golang.org/x/sync/semaphore" "github.com/cortexproject/cortex/pkg/util/flagext" cortextls "github.com/cortexproject/cortex/pkg/util/tls" @@ -50,6 +51,15 @@ type TCPTransportConfig struct { // Timeout for writing packet data. Zero = no timeout. PacketWriteTimeout time.Duration `yaml:"packet_write_timeout"` + // Timeout for reading inbound packet data. Zero = no timeout. + PacketReadTimeout time.Duration `yaml:"packet_read_timeout"` + + // Maximum size in bytes of a single inbound packet. Zero = no limit. + MaxPacketSize int64 `yaml:"max_packet_size"` + + // Maximum number of concurrent inbound TCP connections. Zero = no limit. + MaxConcurrentConnections int `yaml:"max_concurrent_connections"` + // Transport logs lot of messages at debug level, so it deserves an extra flag for turning it on TransportDebug bool `yaml:"-"` @@ -72,6 +82,9 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.") f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 5*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") + f.DurationVar(&cfg.PacketReadTimeout, prefix+"memberlist.packet-read-timeout", 5*time.Second, "Timeout for reading packet data from inbound connections. 0 = no limit.") + f.Int64Var(&cfg.MaxPacketSize, prefix+"memberlist.max-packet-size", 1*1024*1024 /*1MB*/, "Maximum size in bytes of an inbound gossip packet. 0 = no limit.") + f.IntVar(&cfg.MaxConcurrentConnections, prefix+"memberlist.max-concurrent-connections", 100, "Maximum number of concurrent inbound TCP connections. 0 = no limit.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") @@ -90,6 +103,9 @@ type TCPTransport struct { tcpListeners []net.Listener tlsConfig *tls.Config + // connSemaphore limits the number of concurrent inbound TCP connections. + connSemaphore *semaphore.Weighted + shutdown atomic.Int32 advertiseMu sync.RWMutex @@ -107,6 +123,9 @@ type TCPTransport struct { sentPacketsBytes prometheus.Counter sentPacketsErrors prometheus.Counter unknownConnections prometheus.Counter + rejectedConnections prometheus.Counter + activeConnections prometheus.Gauge + packetReceiveDuration prometheus.Histogram } // NewTCPTransport returns a new tcp-based transport with the given configuration. On @@ -125,6 +144,10 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger) (*TCPTranspor connCh: make(chan net.Conn), } + if config.MaxConcurrentConnections > 0 { + t.connSemaphore = semaphore.NewWeighted(int64(config.MaxConcurrentConnections)) + } + var err error if config.TLSEnabled { t.tlsConfig, err = config.TLS.GetTLSConfig() @@ -222,7 +245,27 @@ func (t *TCPTransport) tcpListen(tcpLn net.Listener) { // No error, reset loop delay loopDelay = 0 - go t.handleConnection(conn) + // Enforce concurrent connection via semaphore. + if t.connSemaphore != nil { + if !t.connSemaphore.TryAcquire(1) { + t.rejectedConnections.Inc() + level.Debug(t.logger).Log("msg", "max concurrent connections reached, closing connection", "remote", conn.RemoteAddr()) + _ = conn.Close() + continue + } + } + + t.activeConnections.Inc() + go func() { + // handleConnection returns true when it wrapped the conn in a + // semaphoreConn and transferred ownership of the slot to that + // wrapper (stream path). In that case we must not release here. + semTransferred := t.handleConnection(conn) + if t.connSemaphore != nil && !semTransferred { + t.connSemaphore.Release(1) + } + t.activeConnections.Dec() + }() } } @@ -235,7 +278,7 @@ func (t *TCPTransport) debugLog() log.Logger { return noopLogger } -func (t *TCPTransport) handleConnection(conn net.Conn) { +func (t *TCPTransport) handleConnection(conn net.Conn) (semTransferred bool) { t.debugLog().Log("msg", "New connection", "addr", conn.RemoteAddr()) closeConn := true @@ -245,6 +288,15 @@ func (t *TCPTransport) handleConnection(conn net.Conn) { } }() + // Apply a read deadline for the entire packet receive so that a slow or + // adversarial peer cannot hold the goroutine open indefinitely. + if t.cfg.PacketReadTimeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(t.cfg.PacketReadTimeout)); err != nil { + level.Warn(t.logger).Log("msg", "failed to set read deadline", "err", err, "remote", conn.RemoteAddr()) + return + } + } + // let's read first byte, and determine what to do about this connection msgType := []byte{0} _, err := io.ReadFull(conn, msgType) @@ -256,13 +308,28 @@ func (t *TCPTransport) handleConnection(conn net.Conn) { if messageType(msgType[0]) == stream { t.incomingStreams.Inc() - // hand over this connection to memberlist + // Stream connections are handed off to memberlist which manages them + // independently – clear the deadline so memberlist can use its own + // timeouts, then pass the connection over. + if t.cfg.PacketReadTimeout > 0 { + _ = conn.SetReadDeadline(time.Time{}) + } + + // hand over this connection to memberlist. + // If the semaphore is active, wrap the conn so that the slot is held + // for the real lifetime of the stream. The memberlist will close it. closeConn = false - t.connCh <- conn + if t.connSemaphore != nil { + t.connCh <- &semaphoreConn{Conn: conn, sem: t.connSemaphore} + semTransferred = true + } else { + t.connCh <- conn + } } else if messageType(msgType[0]) == packet { // it's a memberlist "packet", which contains an address and data. t.receivedPackets.Inc() + packetStart := time.Now() // before reading packet, read the address addrLengthBuf := []byte{0} _, err := io.ReadFull(conn, addrLengthBuf) @@ -280,14 +347,26 @@ func (t *TCPTransport) handleConnection(conn net.Conn) { return } - // read the rest to buffer -- this is the "packet" itself - buf, err := io.ReadAll(conn) + var reader io.Reader = conn + if t.cfg.MaxPacketSize > 0 { + // Read one byte beyond the limit so we can detect oversized packets. + reader = io.LimitReader(conn, t.cfg.MaxPacketSize+1) + } + buf, err := io.ReadAll(reader) + t.packetReceiveDuration.Observe(time.Since(packetStart).Seconds()) if err != nil { t.receivedPacketsErrors.Inc() level.Warn(t.logger).Log("msg", "error while reading packet data", "err", err, "remote", conn.RemoteAddr()) return } + // Reject oversized packets + if t.cfg.MaxPacketSize > 0 && int64(len(buf)) > t.cfg.MaxPacketSize { + t.receivedPacketsErrors.Inc() + level.Debug(t.logger).Log("msg", "packet too large, dropping", "size", len(buf), "max", t.cfg.MaxPacketSize, "remote", conn.RemoteAddr()) + return + } + if len(buf) < md5.Size { t.receivedPacketsErrors.Inc() level.Warn(t.logger).Log("msg", "not enough data received", "data_length", len(buf), "remote", conn.RemoteAddr()) @@ -318,6 +397,7 @@ func (t *TCPTransport) handleConnection(conn net.Conn) { t.unknownConnections.Inc() level.Error(t.logger).Log("msg", "unknown message type", "msgType", msgType, "remote", conn.RemoteAddr()) } + return } type addr string @@ -330,6 +410,20 @@ func (a addr) String() string { return string(a) } +// semaphoreConn wraps a net.Conn and releases a semaphore slot exactly once +// when the connection is closed. It is used on the stream path to keep the +// concurrent-connection slot held for the real lifetime of the connection. +type semaphoreConn struct { + net.Conn + sem *semaphore.Weighted + once sync.Once +} + +func (c *semaphoreConn) Close() error { + c.once.Do(func() { c.sem.Release(1) }) + return c.Conn.Close() +} + func (t *TCPTransport) getConnection(addr string, timeout time.Duration) (net.Conn, error) { if t.cfg.TLSEnabled { return tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", addr, t.tlsConfig) @@ -634,4 +728,29 @@ func (t *TCPTransport) registerMetrics(registerer prometheus.Registerer) { Name: "unknown_connections_total", Help: "Number of unknown TCP connections (not a packet or stream)", }) + + t.rejectedConnections = promauto.With(registerer).NewCounter(prometheus.CounterOpts{ + Namespace: t.cfg.MetricsNamespace, + Subsystem: subsystem, + Name: "rejected_connections_total", + Help: "Number of inbound TCP connections rejected because the concurrent connection limit was reached", + }) + + t.activeConnections = promauto.With(registerer).NewGauge(prometheus.GaugeOpts{ + Namespace: t.cfg.MetricsNamespace, + Subsystem: subsystem, + Name: "active_connections", + Help: "Current number of active inbound TCP connections.", + }) + + t.packetReceiveDuration = promauto.With(registerer).NewHistogram(prometheus.HistogramOpts{ + Namespace: t.cfg.MetricsNamespace, + Subsystem: subsystem, + Name: "packet_receive_duration_seconds", + Help: "Duration (in seconds) of inbound packet-type message reads.", + Buckets: prometheus.DefBuckets, + NativeHistogramBucketFactor: 1.1, + NativeHistogramMaxBucketNumber: 100, + NativeHistogramMinResetDuration: 1 * time.Hour, + }) } diff --git a/pkg/ring/kv/memberlist/tcp_transport_test.go b/pkg/ring/kv/memberlist/tcp_transport_test.go index 5f154bd4d8..6a3445417e 100644 --- a/pkg/ring/kv/memberlist/tcp_transport_test.go +++ b/pkg/ring/kv/memberlist/tcp_transport_test.go @@ -5,10 +5,12 @@ import ( "crypto/md5" "fmt" "net" + "sync" "testing" "time" "github.com/go-kit/log" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -116,4 +118,247 @@ func TestTCPTransport_PacketDigestMismatch(t *testing.T) { } assert.Contains(t, logs.String(), "packet digest mismatch") + + require.Eventually(t, func() bool { + return testutil.ToFloat64(transport.activeConnections) == 0 + }, 2*time.Second, 10*time.Millisecond, "activeConnections should be back to 0 after digest mismatch") +} + +func TestTCPTransport_PacketReadTimeout(t *testing.T) { + logger := log.NewNopLogger() + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.BindAddrs = []string{"127.0.0.1"} + cfg.BindPort = 0 + cfg.PacketReadTimeout = 200 * time.Millisecond + + transport, err := NewTCPTransport(cfg, logger) + require.NoError(t, err) + defer transport.Shutdown() //nolint:errcheck + + port := transport.GetAutoBindPort() + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + + // Send packet type byte and address header, then stall – never send payload. + ourAddr := "127.0.0.1:0" + var buf bytes.Buffer + buf.WriteByte(byte(packet)) + buf.WriteByte(byte(len(ourAddr))) + buf.WriteString(ourAddr) + _, err = conn.Write(buf.Bytes()) + require.NoError(t, err) + + // The transport should close the connection after PacketReadTimeout. + // We verify this by trying to read from the conn; once the server side + // closes it due to the deadline, our Read should return an error. + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) //nolint:errcheck + oneByte := make([]byte, 1) + _, readErr := conn.Read(oneByte) + assert.Error(t, readErr, "expected connection to be closed by server after read timeout") + + require.Eventually(t, func() bool { + return testutil.ToFloat64(transport.activeConnections) == 0 + }, 2*time.Second, 10*time.Millisecond, "activeConnections should be back to 0 after read timeout") +} + +func TestTCPTransport_MaxPacketSize(t *testing.T) { + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.BindAddrs = []string{"127.0.0.1"} + cfg.BindPort = 0 + cfg.MaxPacketSize = 128 + + transport, err := NewTCPTransport(cfg, logger) + require.NoError(t, err) + defer transport.Shutdown() //nolint:errcheck + + port := transport.GetAutoBindPort() + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + + // Build a packet that exceeds MaxPacketSize. + ourAddr := "127.0.0.1:0" + oversizedData := make([]byte, int(cfg.MaxPacketSize)+64) + digest := md5.Sum(oversizedData) + + var buf bytes.Buffer + buf.WriteByte(byte(packet)) + buf.WriteByte(byte(len(ourAddr))) + buf.WriteString(ourAddr) + buf.Write(oversizedData) + buf.Write(digest[:]) + + _, err = conn.Write(buf.Bytes()) + require.NoError(t, err) + conn.Close() //nolint:errcheck + + // Packet should be dropped; nothing must arrive on packetCh. + select { + case <-transport.PacketCh(): + t.Fatal("oversized packet should have been dropped") + case <-time.After(500 * time.Millisecond): + // success + } + + assert.Contains(t, logs.String(), "packet too large") + + require.Eventually(t, func() bool { + return testutil.ToFloat64(transport.activeConnections) == 0 + }, 2*time.Second, 10*time.Millisecond, "activeConnections should be back to 0 after oversized packet") +} + +func TestTCPTransport_MaxConcurrentConnections(t *testing.T) { + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + const maxConns = 3 + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.BindAddrs = []string{"127.0.0.1"} + cfg.BindPort = 0 + cfg.PacketReadTimeout = 5 * time.Second + cfg.MaxConcurrentConnections = maxConns + + transport, err := NewTCPTransport(cfg, logger) + require.NoError(t, err) + defer transport.Shutdown() //nolint:errcheck + + port := transport.GetAutoBindPort() + + openSlowConn := func() net.Conn { + c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + // Send packet type byte to enter the packet branch, then stall. + _, err = c.Write([]byte{byte(packet)}) + require.NoError(t, err) + return c + } + + // Fill up the semaphore. + holders := make([]net.Conn, maxConns) + for i := range maxConns { + holders[i] = openSlowConn() + } + defer func() { + for _, c := range holders { + c.Close() //nolint:errcheck + } + }() + + require.Eventually(t, func() bool { + return testutil.ToFloat64(transport.receivedPackets) == float64(maxConns) + }, 2*time.Second, 10*time.Millisecond, "server never accepted %d connections", maxConns) + + // This extra connection should be rejected. + var wg sync.WaitGroup + wg.Go(func() { + extra, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return // connection may be refused outright + } + defer extra.Close() //nolint:errcheck + // Try to read; the server should close immediately. + extra.SetReadDeadline(time.Now().Add(time.Second)) //nolint:errcheck + buf := make([]byte, 1) + extra.Read(buf) //nolint:errcheck + }) + wg.Wait() + + assert.Contains(t, logs.String(), "max concurrent connections reached") + + assert.GreaterOrEqual(t, testutil.ToFloat64(transport.rejectedConnections), float64(1)) + assert.Equal(t, float64(maxConns), testutil.ToFloat64(transport.activeConnections)) +} + +// TestTCPTransport_StreamHoldsSlotUntilClose asserts that +// -memberlist.max-concurrent-connections bounds the number of *live* inbound +// TCP connections: once a stream conn has been handed off to memberlist via +// StreamCh(), its slot stays held until the conn is actually closed. +func TestTCPTransport_StreamHoldsSlotUntilClose(t *testing.T) { + logger := log.NewNopLogger() + + const maxConns = 2 + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.BindAddrs = []string{"127.0.0.1"} + cfg.BindPort = 0 + cfg.PacketReadTimeout = 5 * time.Second + cfg.MaxConcurrentConnections = maxConns + + transport, err := NewTCPTransport(cfg, logger) + require.NoError(t, err) + defer transport.Shutdown() //nolint:errcheck + + port := transport.GetAutoBindPort() + + // Consumer goroutine: drains StreamCh and holds conns alive (never closes + // them) — simulating memberlist actively using streams. + var heldMu sync.Mutex + var held []net.Conn + done := make(chan struct{}) + go func() { + for { + select { + case <-done: + return + case c := <-transport.StreamCh(): + heldMu.Lock() + held = append(held, c) + heldMu.Unlock() + } + } + }() + defer func() { + close(done) + heldMu.Lock() + for _, c := range held { + c.Close() //nolint:errcheck + } + heldMu.Unlock() + }() + + openStreamConn := func() net.Conn { + c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + require.NoError(t, err) + _, err = c.Write([]byte{byte(stream)}) + require.NoError(t, err) + return c + } + + // Fill the semaphore with maxConns live stream handoffs. + clients := make([]net.Conn, 0, maxConns+1) + defer func() { + for _, c := range clients { + c.Close() //nolint:errcheck + } + }() + for range maxConns { + clients = append(clients, openStreamConn()) + } + + // Wait until memberlist side has observed all maxConns streams. + require.Eventually(t, func() bool { + return testutil.ToFloat64(transport.incomingStreams) == float64(maxConns) + }, 2*time.Second, 10*time.Millisecond) + + // One extra stream conn. If the slot is correctly held for the conn's + // real lifetime, the transport must reject this one because all slots + // are still occupied by the held streams above. + clients = append(clients, openStreamConn()) + + require.Eventually(t, func() bool { + return testutil.ToFloat64(transport.rejectedConnections) >= 1 + }, 2*time.Second, 10*time.Millisecond, + "expected extra stream conn to be rejected while %d prior streams are held open, "+ + "but the transport released the slot on handoff — flag does not bound live connections", + maxConns) } diff --git a/schemas/cortex-config-schema.json b/schemas/cortex-config-schema.json index 51aee2c0f5..e93e7c36a4 100644 --- a/schemas/cortex-config-schema.json +++ b/schemas/cortex-config-schema.json @@ -5728,6 +5728,12 @@ "x-cli-flag": "memberlist.left-ingesters-timeout", "x-format": "duration" }, + "max_concurrent_connections": { + "default": 100, + "description": "Maximum number of concurrent inbound TCP connections. 0 = no limit.", + "type": "number", + "x-cli-flag": "memberlist.max-concurrent-connections" + }, "max_join_backoff": { "default": "1m0s", "description": "Max backoff duration to join other cluster members.", @@ -5741,6 +5747,12 @@ "type": "number", "x-cli-flag": "memberlist.max-join-retries" }, + "max_packet_size": { + "default": 1048576, + "description": "Maximum size in bytes of an inbound gossip packet. 0 = no limit.", + "type": "number", + "x-cli-flag": "memberlist.max-packet-size" + }, "message_history_buffer_bytes": { "default": 0, "description": "How much space to use for keeping received and sent messages in memory for troubleshooting (two buffers). 0 to disable.", @@ -5766,6 +5778,13 @@ "x-cli-flag": "memberlist.packet-dial-timeout", "x-format": "duration" }, + "packet_read_timeout": { + "default": "5s", + "description": "Timeout for reading packet data from inbound connections. 0 = no limit.", + "type": "string", + "x-cli-flag": "memberlist.packet-read-timeout", + "x-format": "duration" + }, "packet_write_timeout": { "default": "5s", "description": "Timeout for writing 'packet' data.",