
optimize_for_inference() leaks CUDA context onto unassigned (default) GPU #898

@vmiller987

Description


Search before asking

  • I have searched the RF-DETR issues and found no similar bug report.

Bug

optimize_for_inference() leaks a CUDA context onto GPU 0 (~386 MiB) even when the model lives on a different GPU. Without optimize_for_inference(), memory is allocated only on the target device.

The method uses deepcopy on a CUDA-resident model and torch.jit.trace (detr.py:510-531), both of which can implicitly initialize a CUDA context on the default device (GPU 0). This memory appears in nvidia-smi but not in torch.cuda.memory_allocated(), confirming it is driver-level CUDA context overhead rather than tensor allocations.
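A quick way to see the distinction is to compare nvidia-smi against the caching allocator's own accounting. Sketch below; `allocator_bytes_per_device` is a hypothetical helper, not part of RF-DETR:

```python
import torch

def allocator_bytes_per_device():
    # Bytes the PyTorch caching allocator currently holds on each visible
    # GPU. The ~386 MiB context on GPU 0 does NOT show up here, because it
    # is driver-level overhead outside the allocator -- which is how the
    # leak can be attributed to context creation rather than tensors.
    return {
        f"cuda:{i}": torch.cuda.memory_allocated(i)
        for i in range(torch.cuda.device_count())
    }
```

In the setup above, this should report 0 bytes for cuda:0 after optimize_for_inference(), even while nvidia-smi shows ~386 MiB there.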

Environment

RF-DETR: 1.6.2
OS: RHEL 9.7
Python: 3.13
PyTorch: 2.10 - 2.11
CUDA: 12.9 / Driver: 590.48.01
GPU: 8x NVIDIA GeForce RTX 4090

Minimal Reproducible Example

Without optimize_for_inference, no leak:

import numpy as np
from rfdetr import RFDETRNano

model = RFDETRNano(device="cuda:6")
x = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
model.predict(x)
/home/vmiller/Work/test/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[2026-03-31 10:57:44] [INFO] rf-detr - File rf-detr-nano.pth already exists with correct MD5 hash.
[2026-03-31 10:57:44] [WARNING] rf-detr - Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.
[2026-03-31 10:57:44] [WARNING] rf-detr - Using patch size 16 instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.
[2026-03-31 10:57:45] [INFO] rf-detr - File rf-detr-nano.pth already exists with correct MD5 hash.
[2026-03-31 10:57:46] [WARNING] rf-detr - Model is not optimized for inference. Latency may be higher than expected. You can optimize the model for inference by calling model.optimize_for_inference().
`use_return_dict` is deprecated! Use `return_dict` instead!

Detections(xyxy=array([], shape=(0, 4), dtype=float32), mask=None, confidence=array([], dtype=float32), class_id=array([], dtype=int64), tracker_id=None, data={}, metadata={})

nvidia-smi shows memory allocated only on GPU 6:

[vmiller@gluskap ~]$ nvidia-smi
Tue Mar 31 10:58:52 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 590.48.01              Driver Version: 590.48.01      CUDA Version: 13.1     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 30%   34C    P2             58W /  450W |     544MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00000000:21:00.0 Off |                  Off |
|  0%   35C    P8             19W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        Off |   00000000:41:00.0 Off |                  Off |
|  0%   38C    P8             12W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        Off |   00000000:81:00.0 Off |                  Off |
|  0%   40C    P8             17W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        Off |   00000000:A1:00.0 Off |                  Off |
|  0%   38C    P8             11W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        Off |   00000000:C1:00.0 Off |                  Off |
|  0%   33C    P8             10W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA GeForce RTX 4090        Off |   00000000:E1:00.0 Off |                  Off |
|  0%   34C    P8              4W /  450W |     644MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    6   N/A  N/A         1449007      C   ...er/Work/test/.venv/bin/python        634MiB |
+-----------------------------------------------------------------------------------------+

With optimize_for_inference, a CUDA context leaks onto GPU 0:

import torch
import numpy as np
from rfdetr import RFDETRNano

model = RFDETRNano(device="cuda:6")
model.optimize_for_inference(compile=True, batch_size=1, dtype=torch.float32)

x = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
model.predict(x)
[2026-03-31 10:59:47] [INFO] rf-detr - File rf-detr-nano.pth already exists with correct MD5 hash.
[2026-03-31 10:59:47] [WARNING] rf-detr - Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.
[2026-03-31 10:59:47] [WARNING] rf-detr - Using patch size 16 instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.
[2026-03-31 10:59:48] [INFO] rf-detr - File rf-detr-nano.pth already exists with correct MD5 hash.

Detections(xyxy=array([], shape=(0, 4), dtype=float32), mask=None, confidence=array([], dtype=float32), class_id=array([], dtype=int64), tracker_id=None, data={}, metadata={})

nvidia-smi now shows memory on both GPU 0 and GPU 6:

[vmiller@gluskap ~]$ nvidia-smi
Tue Mar 31 11:00:18 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 590.48.01              Driver Version: 590.48.01      CUDA Version: 13.1     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 30%   34C    P2             57W /  450W |     935MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00000000:21:00.0 Off |                  Off |
|  0%   35C    P8             19W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        Off |   00000000:41:00.0 Off |                  Off |
|  0%   38C    P8             12W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        Off |   00000000:81:00.0 Off |                  Off |
|  0%   40C    P8             16W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        Off |   00000000:A1:00.0 Off |                  Off |
|  0%   38C    P8             11W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        Off |   00000000:C1:00.0 Off |                  Off |
|  0%   33C    P8             10W /  450W |       4MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA GeForce RTX 4090        Off |   00000000:E1:00.0 Off |                  Off |
|  0%   35C    P8              4W /  450W |    1078MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A         1449007      C   ...er/Work/test/.venv/bin/python        386MiB |
|    6   N/A  N/A         1449007      C   ...er/Work/test/.venv/bin/python       1068MiB |
+-----------------------------------------------------------------------------------------+

Additional


Suggested Fix

Wrap optimize_for_inference in the target device's context manager so that any implicit CUDA initialization happens on the correct GPU:

def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32):
    # Pin the default CUDA device so implicit context creation
    # (deepcopy, torch.jit.trace) targets the model's GPU, not GPU 0.
    with torch.cuda.device(self.model.device):
        self.remove_optimized_model()

        self.model.inference_model = deepcopy(self.model.model)
        self.model.inference_model.eval()
        self.model.inference_model.export()

        self._optimized_resolution = self.model.resolution
        self._is_optimized_for_inference = True

        self.model.inference_model = self.model.inference_model.to(dtype=dtype)
        self._optimized_dtype = dtype

        if compile:
            self.model.inference_model = torch.jit.trace(
                self.model.inference_model,
                torch.randn(
                    batch_size, 3, self.model.resolution, self.model.resolution,
                    device=self.model.device, dtype=dtype
                ),
            )
            self._optimized_has_been_compiled = True
            self._optimized_batch_size = batch_size

Minor related note:

optimize_for_inference passes dtype directly to nn.Module.to(dtype=...) at detr.py:520. When a user passes dtype="float32" (a string), this raises a TypeError because .to() expects torch.dtype. Consider adding string-to-dtype coercion or documenting that only torch.dtype values are accepted.
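A sketch of the coercion (`coerce_dtype` is a hypothetical helper name):

```python
import torch

def coerce_dtype(dtype):
    # Accept "float32"/"float16"/"bfloat16"-style strings as well as
    # torch.dtype objects; nn.Module.to(dtype=...) only takes the latter.
    if isinstance(dtype, str):
        resolved = getattr(torch, dtype, None)
        if not isinstance(resolved, torch.dtype):
            raise TypeError(f"Unsupported dtype string: {dtype!r}")
        return resolved
    return dtype
```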

Are you willing to submit a PR?

  • Yes, I'd like to help by submitting a PR!

Metadata


Labels

bug (Something isn't working)
