diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index aafff8f..67dfd25 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -5,6 +5,9 @@ import ( "context" "fmt" "io" + "math" + "math/rand" + "strconv" "strings" "time" @@ -193,6 +196,9 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, httpClient: bi.httpClient, + retryMax: bi.cfg.RetryMax, + retryWaitMin: bi.cfg.RetryWaitMin, + retryWaitMax: bi.cfg.RetryWaitMax, } task.Run() bi.downloadTasks.Enqueue(task) @@ -252,6 +258,9 @@ type cloudFetchDownloadTask struct { resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 httpClient *http.Client + retryMax int + retryWaitMin time.Duration + retryWaitMax time.Duration } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, int64, error) { @@ -295,20 +304,32 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.RowCount, ) downloadStart := time.Now() - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) + rawBody, err := fetchBatchBytes( + cft.ctx, + cft.link, + cft.minTimeToExpiry, + cft.speedThresholdMbps, + cft.httpClient, + cft.retryMax, + cft.retryWaitMin, + cft.retryWaitMax, + ) if err != nil { cft.sendResult(cloudFetchDownloadTaskResult{data: nil, err: err}) return } - // Read all data into memory before closing - buf, err := io.ReadAll(getReader(data, cft.useLz4Compression)) - data.Close() //nolint:errcheck,gosec // G104: close after reading data - downloadMs := time.Since(downloadStart).Milliseconds() - if err != nil { - cft.sendResult(cloudFetchDownloadTaskResult{data: nil, err: err}) - return + buf := rawBody + if cft.useLz4Compression { + // Decompression sits outside the retry loop: malformed LZ4 is data + // corruption, not a transient network condition. 
+ buf, err = io.ReadAll(lz4.NewReader(bytes.NewReader(rawBody))) + if err != nil { + cft.sendResult(cloudFetchDownloadTaskResult{data: nil, err: err}) + return + } } + downloadMs := time.Since(downloadStart).Milliseconds() logger.Debug().Msgf( "CloudFetch: downloaded data for link at offset %d row count %d", @@ -350,43 +371,177 @@ func logCloudFetchSpeed(fullURL string, contentLength int64, duration time.Durat } } +// fetchBatchBytes downloads a single CloudFetch result link and returns the +// raw response body, still compressed if the server used LZ4. Connection-time +// failures, retryable HTTP statuses, and mid-stream body read failures are +// retried up to retryMax times with exponential backoff and equal jitter. +// Decompression and IPC parsing stay with the caller because those failures are +// not transient network conditions. +// +// Link expiry is rechecked after each backoff: a long retry chain may outlive +// a presigned URL, and continuing past expiry is guaranteed to fail. 
func fetchBatchBytes( ctx context.Context, link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, httpClient *http.Client, -) (io.ReadCloser, error) { - if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { - return nil, errors.New(dbsqlerr.ErrLinkExpired) + retryMax int, + retryWaitMin time.Duration, + retryWaitMax time.Duration, +) ([]byte, error) { + if retryMax < 0 { + retryMax = 0 + } + + var ( + lastErr error + lastStatus int + lastRetryAfter string + ) + + for attempt := 0; attempt <= retryMax; attempt++ { + if attempt > 0 { + wait := cloudFetchBackoff(attempt, retryWaitMin, retryWaitMax, lastRetryAfter) + logger.Debug().Msgf( + "CloudFetch: retrying download of link at offset %d (attempt %d/%d) in %v; lastStatus=%d lastErr=%v", + link.StartRowOffset, attempt, retryMax, wait, lastStatus, lastErr, + ) + t := time.NewTimer(wait) + select { + case <-ctx.Done(): + if !t.Stop() { + <-t.C + } + return nil, ctx.Err() + case <-t.C: + } + } + + // Check link expiry *after* backoff: a long retry chain may outlive a + // presigned URL, and there's no point spending another HTTP attempt + // (or another retry) on a link we know will be rejected. + if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { + return nil, errors.New(dbsqlerr.ErrLinkExpired) + } + + req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil) + if err != nil { + return nil, err + } + if link.HttpHeaders != nil { + for key, value := range link.HttpHeaders { + req.Header.Set(key, value) + } + } + + startTime := time.Now() + res, err := httpClient.Do(req) + if err != nil { + // Caller cancellation is terminal; otherwise treat transport errors + // (TCP RST, TLS timeout, etc.) as transient. 
+ if ctx.Err() != nil { + return nil, ctx.Err() + } + lastErr = err + lastStatus = 0 + lastRetryAfter = "" + continue + } + + if res.StatusCode == http.StatusOK { + // Read the full body inside the retry loop so truncated 200 OK + // responses are retried just like header-time failures. + buf, readErr := io.ReadAll(res.Body) + res.Body.Close() //nolint:errcheck,gosec // G104: close after drain + if readErr != nil { + if ctx.Err() != nil { + return nil, ctx.Err() + } + lastErr = readErr + lastStatus = 0 + lastRetryAfter = "" + continue + } + logCloudFetchSpeed(link.FileLink, int64(len(buf)), time.Since(startTime), speedThresholdMbps) + return buf, nil + } + + // Drain and close so the underlying connection can be reused. + _, _ = io.Copy(io.Discard, res.Body) + res.Body.Close() //nolint:errcheck,gosec // G104: closing after drain + + lastStatus = res.StatusCode + lastErr = nil + lastRetryAfter = res.Header.Get("Retry-After") + + if !isCloudFetchRetryableStatus(res.StatusCode) { + msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode) + return nil, dbsqlerrint.NewDriverError(ctx, msg, nil) + } } - // TODO: Retry on HTTP errors - req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil) - if err != nil { - return nil, err + if lastStatus != 0 { + // lastErr is nil here by construction: the HTTP-status branch above + // explicitly clears it on every iteration. The status code is captured + // in msg, so there's no underlying error to wrap. 
+ return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("%s: %s %d (after %d retries)", errArrowRowsCloudFetchDownloadFailure, "HTTP error", lastStatus, retryMax), nil) } + msg := fmt.Sprintf("%s: %v (after %d retries)", errArrowRowsCloudFetchDownloadFailure, lastErr, retryMax) + return nil, dbsqlerrint.NewDriverError(ctx, msg, lastErr) +} + +// cloudFetchRetryableStatuses lists HTTP status codes from object storage that +// indicate transient conditions and warrant a retry. Mirrors AWS S3 guidance +// for SlowDown (503) / InternalError (500) plus the general 408/429/502/504. +var cloudFetchRetryableStatuses = map[int]struct{}{ + http.StatusRequestTimeout: {}, // 408 + http.StatusTooManyRequests: {}, // 429 + http.StatusInternalServerError: {}, // 500 + http.StatusBadGateway: {}, // 502 + http.StatusServiceUnavailable: {}, // 503 + http.StatusGatewayTimeout: {}, // 504 +} - if link.HttpHeaders != nil { - for key, value := range link.HttpHeaders { - req.Header.Set(key, value) +func isCloudFetchRetryableStatus(status int) bool { + _, ok := cloudFetchRetryableStatuses[status] + return ok +} + +// cloudFetchBackoff returns the wait before retry attempt N (1-based). The +// base delay is exponential — waitMin * 2^(attempt-1) capped at waitMax — with +// equal jitter applied: the actual sleep is uniformly distributed in +// [base/2, base]. Equal jitter (rather than no jitter) is used to spread +// synchronized retries across the up-to-MaxDownloadThreads concurrent +// downloads, which would otherwise hammer the storage endpoint in lockstep +// after a region-wide blip. If the server returned a parseable integer +// Retry-After header, that value (in seconds) is honored instead, capped at +// waitMax. HTTP-date Retry-After values are ignored — same as the Thrift +// client's backoff. 
+func cloudFetchBackoff(attempt int, waitMin, waitMax time.Duration, retryAfter string) time.Duration {
+	if retryAfter != "" {
+		if secs, err := strconv.ParseInt(retryAfter, 10, 64); err == nil && secs >= 0 {
+			// Cap while still in whole seconds: time.Duration(secs)*time.Second wraps int64 for huge Retry-After values.
+			if secs > int64(waitMax/time.Second) {
+				return waitMax
+			}
+			return time.Duration(secs) * time.Second
 		}
 	}
 
-	startTime := time.Now()
-	res, err := httpClient.Do(req)
-	if err != nil {
-		return nil, err
+	expo := float64(waitMin) * math.Pow(2, float64(attempt-1))
+	if expo > float64(waitMax) || math.IsInf(expo, 0) {
+		expo = float64(waitMax)
 	}
-	if res.StatusCode != http.StatusOK {
-		msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode)
-		return nil, dbsqlerrint.NewDriverError(ctx, msg, err)
+	base := time.Duration(expo)
+	if base <= 0 {
+		return 0
 	}
-
-	// Log download speed metrics
-	logCloudFetchSpeed(link.FileLink, res.ContentLength, time.Since(startTime), speedThresholdMbps)
-
-	return res.Body, nil
+	half := base / 2
+	if half <= 0 {
+		return base
+	}
+	return half + time.Duration(rand.Int63n(int64(half))) //nolint:gosec // G404: jitter only, non-cryptographic
 }
 
 func getReader(r io.Reader, useLz4Compression bool) io.Reader {
diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go
index 52e7dc4..d8d942b 100644
--- a/internal/rows/arrowbased/batchloader_test.go
+++ b/internal/rows/arrowbased/batchloader_test.go
@@ -33,6 +33,26 @@ func TestCloudFetchIterator(t *testing.T) {
 	}))
 	defer server.Close()
 
+	writeTruncatedOK := func(t *testing.T, w http.ResponseWriter, body []byte) {
+		t.Helper()
+		hj, ok := w.(http.Hijacker)
+		if !ok {
+			t.Errorf("ResponseWriter does not support Hijacker")
+			return
+		}
+		conn, bufrw, err := hj.Hijack()
+		if err != nil {
+			t.Errorf("hijack failed: %v", err)
+			return
+		}
+		_, _ = fmt.Fprintf(bufrw, "HTTP/1.1 200 OK\r\nContent-Length: 1000000\r\nConnection: close\r\n\r\n")
+		if len(body) > 0 {
+			_, _ = bufrw.Write(body)
+		}
+		_ = bufrw.Flush()
+		_ =
conn.Close() + } + t.Run("should fetch all the links", func(t *testing.T) { cloudFetchHeaders := map[string]string{ "foo": "bar", @@ -346,6 +366,396 @@ func TestCloudFetchIterator(t *testing.T) { assert.Nil(t, nextErr) assert.NotNil(t, sab) }) + + t.Run("should retry transient HTTP 503 and eventually succeed", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 4 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + sab, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab) + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts), "expected 2 retries before success") + }) + + t.Run("should retry mid-stream body read failures (200 OK then connection drop)", func(t *testing.T) { + var attempts int32 + realBody := generateMockArrowBytes(generateArrowRecord()) + handler = func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + writeTruncatedOK(t, w, []byte("partial")) + return + } + w.WriteHeader(http.StatusOK) + if _, err := w.Write(realBody); err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 4 + cfg.RetryWaitMin = 1 * 
time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + sab, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts), "expected first attempt to fail mid-stream, second to succeed") + }) + + t.Run("should fail after exhausting retries on persistent body-read failures", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + writeTruncatedOK(t, w, nil) + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 2 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + _, nextErr := bi.Next() + assert.NotNil(t, nextErr) + assert.ErrorContains(t, nextErr, "after 2 retries") + // initial attempt + RetryMax retries + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts)) + }) + + t.Run("should retry transient HTTP 500 and eventually succeed", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 2 { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := 
int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 4 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + sab, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab) + assert.Equal(t, int32(2), atomic.LoadInt32(&attempts)) + }) + + t.Run("should fail after exhausting retries on persistent 503", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusServiceUnavailable) + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 2 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + _, nextErr := bi.Next() + assert.NotNil(t, nextErr) + assert.ErrorContains(t, nextErr, fmt.Sprintf("HTTP error %d", http.StatusServiceUnavailable)) + assert.ErrorContains(t, nextErr, "after 2 retries") + // initial attempt + RetryMax retries + assert.Equal(t, int32(3), atomic.LoadInt32(&attempts)) + }) + + t.Run("should not retry on non-retryable status (403)", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusForbidden) + } + + startRowOffset := 
int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 5 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 5 * time.Millisecond + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + _, nextErr := bi.Next() + assert.NotNil(t, nextErr) + assert.ErrorContains(t, nextErr, fmt.Sprintf("HTTP error %d", http.StatusForbidden)) + assert.NotContains(t, nextErr.Error(), "after") + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts), "non-retryable status must fail on first attempt") + }) + + t.Run("should detect link expiry between retries", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.Header().Set("Retry-After", "3") + w.WriteHeader(http.StatusServiceUnavailable) + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 5 + cfg.RetryWaitMin = 1 * time.Millisecond + cfg.RetryWaitMax = 3 * time.Second + expiryTime := time.Now().Unix() + 2 + + bi, err := NewCloudBatchIterator( + context.Background(), + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: expiryTime, + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + _, nextErr := bi.Next() + assert.NotNil(t, nextErr) + assert.ErrorContains(t, nextErr, dbsqlerr.ErrLinkExpired) + // The retry sleeps past expiry, then short-circuits before another GET. 
+ assert.Equal(t, int32(1), atomic.LoadInt32(&attempts)) + }) + + t.Run("should respect context cancellation during backoff", func(t *testing.T) { + var attempts int32 + handler = func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(http.StatusServiceUnavailable) + } + + startRowOffset := int64(100) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.RetryMax = 5 + cfg.RetryWaitMin = 500 * time.Millisecond + cfg.RetryWaitMax = 1 * time.Second + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + bi, err := NewCloudBatchIterator( + ctx, + []*cli_service.TSparkArrowResultLink{{ + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }}, + startRowOffset, + nil, + cfg, + nil, + ) + assert.Nil(t, err) + + started := time.Now() + _, nextErr := bi.Next() + elapsed := time.Since(started) + + assert.NotNil(t, nextErr) + // Cancellation should land well before all retries would otherwise complete + // (5 * 500ms+ = 2.5s+ minimum without cancel). 
+ assert.Less(t, elapsed, 1*time.Second, "context cancel should abort retry backoff promptly") + }) +} + +func TestCloudFetchBackoff(t *testing.T) { + t.Run("retry-after integer seconds is honored", func(t *testing.T) { + got := cloudFetchBackoff(1, 100*time.Millisecond, 60*time.Second, "2") + assert.Equal(t, 2*time.Second, got) + }) + + t.Run("retry-after is capped at waitMax", func(t *testing.T) { + got := cloudFetchBackoff(1, 100*time.Millisecond, 1*time.Second, "100") + assert.Equal(t, 1*time.Second, got) + }) + + t.Run("retry-after http-date is ignored, falls back to exponential", func(t *testing.T) { + minWait := 100 * time.Millisecond + got := cloudFetchBackoff(1, minWait, 10*time.Second, "Tue, 15 Nov 1994 08:12:31 GMT") + // attempt=1 base = minWait; equal jitter in [minWait/2, minWait] + assert.GreaterOrEqual(t, got, minWait/2) + assert.LessOrEqual(t, got, minWait) + }) + + t.Run("exponential is capped at waitMax", func(t *testing.T) { + maxWait := 200 * time.Millisecond + // 100ms * 2^9 = 51200ms, capped at 200ms; equal jitter -> [100ms, 200ms] + for i := 0; i < 50; i++ { + got := cloudFetchBackoff(10, 100*time.Millisecond, maxWait, "") + assert.GreaterOrEqual(t, got, maxWait/2) + assert.LessOrEqual(t, got, maxWait) + } + }) + + t.Run("base grows exponentially with attempt", func(t *testing.T) { + minWait, maxWait := 100*time.Millisecond, 10*time.Second + // attempt=1 -> base 100ms, jitter [50ms,100ms] + // attempt=3 -> base 400ms, jitter [200ms,400ms] + for i := 0; i < 50; i++ { + got1 := cloudFetchBackoff(1, minWait, maxWait, "") + got3 := cloudFetchBackoff(3, minWait, maxWait, "") + assert.GreaterOrEqual(t, got1, 50*time.Millisecond) + assert.LessOrEqual(t, got1, 100*time.Millisecond) + assert.GreaterOrEqual(t, got3, 200*time.Millisecond) + assert.LessOrEqual(t, got3, 400*time.Millisecond) + } + }) + + t.Run("zero waitMin returns zero", func(t *testing.T) { + got := cloudFetchBackoff(1, 0, 0, "") + assert.Equal(t, time.Duration(0), got) + }) +} + +func 
TestCloudFetchRetryableStatus(t *testing.T) { + retryable := []int{408, 429, 500, 502, 503, 504} + notRetryable := []int{200, 201, 301, 302, 400, 401, 403, 404, 409, 410, 501} + + for _, s := range retryable { + assert.True(t, isCloudFetchRetryableStatus(s), "%d should be retryable", s) + } + for _, s := range notRetryable { + assert.False(t, isCloudFetchRetryableStatus(s), "%d should not be retryable", s) + } } func TestCloudFetchSchemaOverride(t *testing.T) { @@ -686,6 +1096,116 @@ func TestCloudFetchIterator_CloseReleasesInFlightDownloads(t *testing.T) { countDownloadTaskGoroutines()) } +// TestCloudFetchIterator_CloseReleasesAfterRetry is the retry-path counterpart +// to TestCloudFetchIterator_CloseReleasesInFlightDownloads. PR #355 added +// HTTP retry to fetchBatchBytes, materially lengthening the window during +// which a download task can produce a result after the iterator has been +// closed. The result-send must therefore go through cft.sendResult so the +// ctx.Done arm fires and the goroutine exits; a regression that routes the +// final result back through `cft.resultChan <-` (e.g. a sloppy merge of +// #355 onto pre-#357 main) blocks forever and pins the buffered Arrow body +// in the heap. +// +// The server flaps once (503 then 200) to force the retry path so the task +// produces its result after at least one backoff. The first task's result +// is consumed by the foreground Next(); the remaining MaxDownloadThreads-1 +// tasks have queued results that nobody is reading. bi.Close() must release +// them. +func TestCloudFetchIterator_CloseReleasesAfterRetry(t *testing.T) { + arrowBytes := generateMockArrowBytes(generateArrowRecord()) + + // Each link is requested at most twice: first attempt returns 503, + // second returns the body. Tracked per-link via a map keyed on the + // link's StartRowOffset so MaxDownloadThreads parallel requests stay + // independent. 
+ var mu sync.Mutex + attempts := map[string]int{} + var inFlightSecond atomic.Int64 + release := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Use the query string injected per-link below as the key. + key := r.URL.RawQuery + mu.Lock() + attempts[key]++ + n := attempts[key] + mu.Unlock() + + if n == 1 { + // First attempt: serve a retryable 503 so the task enters the + // retry/backoff path. + w.WriteHeader(http.StatusServiceUnavailable) + return + } + // Second attempt: block until the test releases, then return + // success. The block lets us observe all MaxDownloadThreads tasks + // having made it past the retry and parked on the channel send. + inFlightSecond.Add(1) + <-release + w.WriteHeader(http.StatusOK) + _, _ = w.Write(arrowBytes) + })) + defer server.Close() + + const nLinks = 20 + links := make([]*cli_service.TSparkArrowResultLink, nLinks) + for i := range links { + links[i] = &cli_service.TSparkArrowResultLink{ + // Unique query string per link so the server's per-link + // attempt counter doesn't conflate parallel requests. + FileLink: fmt.Sprintf("%s/?id=%d", server.URL, i), + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: int64(i), + RowCount: 1, + } + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 10 + // Backoff is observable but short — the test should still complete in + // well under a second. + cfg.RetryMax = 3 + cfg.RetryWaitMin = 10 * time.Millisecond + cfg.RetryWaitMax = 50 * time.Millisecond + + bi, err := NewCloudBatchIterator(context.Background(), links, 0, nil, cfg, nil) + assert.Nil(t, err) + + // One foreground reader. It drains exactly one batch and then exits; + // the remaining MaxDownloadThreads-1 in-flight tasks have nobody reading. 
+ go func() { _, _ = bi.Next() }() + + // Wait for every concurrent task to have completed its 503 retry, slept + // the backoff, and reached its second-attempt handler (where it's + // parked waiting for the release channel). + assert.Eventually(t, func() bool { + return inFlightSecond.Load() == int64(cfg.MaxDownloadThreads) + }, 5*time.Second, 10*time.Millisecond, + "expected %d second-attempt requests, got %d", + cfg.MaxDownloadThreads, inFlightSecond.Load()) + + // Release: every task now finishes its successful HTTP read and + // attempts to send its result. The foreground Next() consumes one; + // the rest are queued in cloudIPCStreamIterator.downloadTasks and the + // goroutines park on resultChan (via sendResult). + close(release) + + // Give the goroutines time to reach the channel send. + time.Sleep(200 * time.Millisecond) + + // Close without draining. cft.ctx is cancelled for every queued task, + // so each parked sendResult must unblock via its ctx.Done arm. + bi.Close() + + // All cloudFetchDownloadTask.Run goroutines must exit. + assert.Eventually(t, func() bool { + return countDownloadTaskGoroutines() == 0 + }, 5*time.Second, 50*time.Millisecond, + "cloudFetchDownloadTask goroutines leaked after Close on retry path: have %d", + countDownloadTaskGoroutines()) +} + // countDownloadTaskGoroutines returns the number of live goroutines whose // stack includes cloudFetchDownloadTask.Run. Used to detect the leak in // issue #356.