PyTorch DataLoader that pre-fetches tensors into GPU memory raises CUDA warnings on exit

I’m trying to define a DataLoader that pre-fetches tensors directly into GPU memory (not pinned memory) in a separate worker process. The goal is twofold: the main process shouldn’t have to wait for the CPU-to-GPU transfer of every batch, and my training loop shouldn’t need to check whether the data is already in device memory and transfer it manually, because the DataLoader puts each batch on the correct device automatically. I defined my DataLoader like this:

import torch

def collate_gpu(batch):
    # Collate on the CPU as usual, then move the inputs and targets onto
    # the GPU inside the DataLoader worker process
    x, t = torch.utils.data.default_collate(batch)
    return x.to(device="cuda:0"), t.to(device="cuda:0")

kwargs = {
    "num_workers": 1,
    "prefetch_factor": 2,
    "persistent_workers": True,
    "collate_fn": collate_gpu,
}
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)
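With this DataLoader, the training loop can consume batches without any explicit device transfer. A minimal sketch of what I mean (model, optimiser, loss_fn and num_epochs are placeholders for my actual training code):

for epoch in range(num_epochs):
    for x, t in train_loader:
        # x and t already live on cuda:0, so no .to(device) call is needed here
        y = model(x)
        loss = loss_fn(y, t)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()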

When my program exits, I get a large number of CUDA warnings:

[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)

If I remove "persistent_workers": True, I get similar warnings every time an iterator finishes iterating over train_loader, in addition to the following warning:

[W C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

If I remove "num_workers": 1, "prefetch_factor": 2, and "persistent_workers": True, I don’t get any warnings at all, but then I also lose the benefit of pre-loading the data into GPU memory.
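In other words, the warning-free configuration is just the single-process loader below, where the CPU-to-GPU copy inside collate_gpu happens synchronously in the main process each time a batch is requested (a minimal sketch of what is left after removing those kwargs):

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_gpu,  # the transfer now blocks the main process for every batch
)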

It seems clear that the warnings are caused by the DataLoader’s worker process being killed while it still holds pre-fetched data on the GPU that hasn’t been freed.

Is there any way I can free the pre-fetched GPU data in the worker process before the train_loader object is destroyed (which I assume should fix the warnings)?

(PS: I know it is possible to silence CUDA warnings, but here I am interested in fixing the root cause of these warnings.)

Edit: minimal working example (MWE)

The following MWE recreates the warning messages:

import torch
import torchvision

def collate_gpu(batch):
    # device=0 is shorthand for "cuda:0"; the batch is moved onto the GPU
    # inside the DataLoader worker process
    x, t = torch.utils.data.default_collate(batch)
    return x.to(device=0), t.to(device=0)

train_dataset = torchvision.datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,
    num_workers=1,
    prefetch_factor=2,
    persistent_workers=True,
    collate_fn=collate_gpu,
)

if __name__ == "__main__":
    x, t = next(iter(train_loader))
    print("About to call `del train_loader`...")
    del train_loader
    print("Finished `del train_loader`")

Console output:

About to call `del train_loader`...
[W C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: driver shutting down (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
[W CUDAGuardImpl.h:46] Warning: CUDA warning: driver shutting down (function uncheckedGetDevice)
[W CUDAGuardImpl.h:62] Warning: CUDA warning: invalid device ordinal (function uncheckedSetDevice)
Finished `del train_loader`

In reality I don’t call del train_loader, but I initialise train_loader inside a function, and when the function exits, the result is the same.
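Roughly, the real code looks like the sketch below (the function name is just illustrative); the warnings are printed when the function returns and the local train_loader is garbage-collected:

def get_first_batch():
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=100,
        num_workers=1,
        prefetch_factor=2,
        persistent_workers=True,
        collate_fn=collate_gpu,
    )
    # The local train_loader (and its persistent worker process) is destroyed
    # when this function returns, which is when the warnings appear
    return next(iter(train_loader))

x, t = get_first_batch()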

Weirdly, if I don’t call del train_loader (and train_loader is not defined inside a function), then there are no warning messages at all.