CUDA error/RuntimeError when using decode_jpeg() in a DataLoader

I’m trying to decode a large number of jpeg images directly onto my GPU via a DataLoader, but am running into issues.

I’m running with batch_size=64 and num_workers=8.

The code runs as expected when I pass device="cpu" to decode_jpeg().

My dataset is defined as follows:

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.files = sorted(glob.glob("my_dir/*.jpeg"))

    def __getitem__(self, idx):
        im_u8 = read_file(self.files[idx])  # raw JPEG bytes as a uint8 tensor
        # decode directly on the chosen device, then scale to [0, 1]
        im_nv = decode_jpeg(im_u8, device=self.device).float() / 255
        return im_nv

    def __len__(self):
        return len(self.files)
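
The DataLoader itself is created in the usual way, roughly like this (simplified, with the surrounding evaluation code omitted):

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=64, num_workers=8)

for batch in dataloader:
    ...  # run the model on each batch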

Initially I was running into a CUDA initialization error, which I fixed by setting mp.set_start_method("spawn"). I am now running into the following cryptic RuntimeError:

Traceback (most recent call last):
  File "/code/scripts/predict-noncapture-hauls.py", line 161, in <module>
    run()
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/code/scripts/predict-noncapture-hauls.py", line 147, in run
    predictions, metadata = evaluate_event_yolov5(
  File "/code/scripts/predict-noncapture-hauls.py", line 75, in evaluate_event_yolov5
    for batch in dataloader:
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.9/dist-packages/torch/_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py", line 136, in default_collate
    storage = elem.storage()._new_shared(numel)
  File "/usr/local/lib/python3.9/dist-packages/torch/_tensor.py", line 181, in storage
    storage = self._storage()
RuntimeError: it != attype_to_py_storage_type.end() INTERNAL ASSERT FAILED at "…/torch/csrc/DynamicTypes.cpp":69, please report a bug to PyTorch. Failed to get the Python type of _UntypedStorage.

What dtype is im_nv? It seems as if the type cannot be mapped to a storage.

im_nv is of type torch.float32 and shape 3x640x640. im_u8 has type torch.uint8.

Note that it works as expected when decoding to CPU.
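
For reference, I checked this with a quick debug print inside __getitem__, something like:

print(im_nv.dtype, im_nv.shape, im_nv.device)  # torch.float32, torch.Size([3, 640, 640]), cuda:0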

Edit: I replied earlier saying I was no longer receiving this issue. It turns out I'd increased the batch size, which was causing a different issue. I am still receiving the RuntimeError.

This is really odd. I put a flag in after the decode_jpeg() line, which showed that it actually loads the first few images successfully (possibly all of the images in each worker's batch?) and then breaks.

It looks like it's breaking at the collate step. Could it be an issue with the tensors in each batch being on the GPU prior to collation?
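
If it really is the shared-memory path in default_collate that trips this up, I suppose a custom collate_fn that just stacks the worker's tensors would sidestep that particular call. An untested sketch of what I mean:

def cuda_collate(batch):
    # plain stack, avoiding default_collate's shared-memory storage allocation
    return torch.stack(batch, 0)

dataloader = DataLoader(dataset, batch_size=64, num_workers=8, collate_fn=cuda_collate)

(I haven't verified whether the batches survive the trip back to the main process this way.)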

Update: moving the tensors back to CPU in __getitem__ stops the error from occurring, though this isn't ideal since we then have to move the batches back onto the GPU.
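
Concretely, the workaround is just returning a CPU tensor from __getitem__ and moving each batch back afterwards, along these lines:

    def __getitem__(self, idx):
        im_u8 = read_file(self.files[idx])
        im_nv = decode_jpeg(im_u8, device=self.device).float() / 255
        return im_nv.cpu()  # hand a CPU tensor to the collate step

and then in the loop:

for batch in dataloader:
    batch = batch.to("cuda")  # pay for the extra transfer here
    ...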

A further strange observation: when I load the JPEG via OpenCV, turn it into a tensor, and move that onto the GPU, it works as expected.

For reference, loading an image via torchvision.io.decode_jpeg():

im_u8 = read_file(self.files[idx])  # raw JPEG bytes as a uint8 tensor
im_nv = decode_jpeg(im_u8, device="cuda").float() / 255  # decoded directly on the GPU

versus via OpenCV:

im = cv2.imread(self.files[idx])  # HWC, BGR, uint8
im = im.transpose((2, 0, 1))[::-1]  # HWC -> CHW, then BGR -> RGB
im = np.ascontiguousarray(im)
im = torch.from_numpy(im).float() / 255
im = im.to("cuda")  # decoded on the CPU, then moved to the GPU

The produced tensors are not exactly the same (minor decoding differences I guess) but are close. They are the same shape and data type. The OpenCV option works as expected, while the PyTorch one produces our friend the RuntimeError :confused:

(Note: I tried making im_nv contiguous, to no avail.)
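
"Close" here is from a quick elementwise comparison along these lines (im_cv being the OpenCV-decoded tensor from above, moved to the same device):

print((im_nv - im_cv).abs().max())  # small but nonzero decoding differences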

Thanks for the follow-up! I'm currently unsure what's causing this, but could you create an issue on the torchvision repository so that we can track and debug it further, please?

Sure, will condense it down a bit :sweat_smile:. Thanks for your help so far!

I tried to reproduce the issue using torchvision==0.13.0+cu116 (the RC) with a single image and this code:

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.path = "fake.jpeg"

    def __getitem__(self, idx):
        im_u8 = torchvision.io.read_file(self.path)
        im_nv = torchvision.io.decode_jpeg(im_u8, device=self.device).float() / 255
        return im_nv

    def __len__(self):
        return 10


def main():
    dataset = MyDataset()
    loader = DataLoader(dataset, num_workers=2)

    for data in loader:
        print(data)


if __name__ == "__main__":
    # "spawn" is needed to use CUDA inside DataLoader worker processes
    mp.set_start_method("spawn")
    main()

which seems to work fine. Could you also try out the latest nightly release?

I just tried your example with the nightly release, and am no longer getting the RuntimeError. Unfortunately I’m running into the memory leak bug known to occur before CUDA 11.6 (although my machine does seem to be running 11.6 :person_shrugging:).

I wasn't able to find the distribution of torchvision you mentioned. I can only see the CUDA 10.2 and 11.3 versions mentioned on the website; could you let me know how you installed it?

This should work:

pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu116

Could you also point me to the known memory leak issue in previous versions, please?

Here is a thread on the memory leak issue: Torchvision decode_jpeg memory leak · Issue #4378 · pytorch/vision · GitHub

I'm no longer getting the warning on torchvision==0.13.0+cu116, but am still running into memory issues. Using a smaller batch size stops the GPU memory problems, though for some reason this is substantially slower than CPU decoding with a larger batch size.

I am also receiving the following warning when decoding on GPU:

[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

I’m currently investigating to see whether this could be related to the memory leak issues.
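
To start with, I'll probably just log GPU memory per batch to see whether allocations keep growing, something like:

for i, batch in enumerate(dataloader):
    # track allocation growth across batches to spot a leak
    print(i,
          torch.cuda.memory_allocated() / 1e6, "MB allocated,",
          torch.cuda.max_memory_allocated() / 1e6, "MB peak")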