Image much larger than expected on GPU

I’m using PyTorch/TensorRT for real-time inference, which means I take an image directly from the camera, preprocess it (rescaling etc.), and then run inference. I found the numpy/OpenCV preprocessing pretty slow, so I wanted to do it on the GPU, since I would have to copy the image to the GPU either way.

This is the code that I’m using:

import numpy as np
import torch
from torchvision.transforms import functional as F


def pre_process_torch_cuda(img: np.ndarray, img_size=1280):
    or_h, or_w = img.shape[:2]

    # scale so that the longer side fits into img_size
    scale_factor = min(img_size / or_h, img_size / or_w)
    new_h = int(scale_factor * or_h)
    new_w = int(scale_factor * or_w)

    print((new_h, new_w))

    # upload to the GPU and resize there
    cuda_img = torch.from_numpy(img).float().to(torch.device('cuda'))
    res_img = F.resize(cuda_img, size=[new_h, new_w])
    # move channels to the front and normalize to [0, 1]
    cor_chan_img = res_img.permute(2, 0, 1) / 255

    # paste the resized image onto a square canvas
    canvas = torch.zeros((3, img_size, img_size), device=cuda_img.device)
    canvas[:, :new_h, :new_w] = cor_chan_img
    return canvas

The input images have size 1936x1216x3, so by my calculation that is around 7×10^6 values, and as float32 (4 bytes each) that should be around 28×10^6 bytes, i.e. roughly 28 MB. However, the line cuda_img = torch.from_numpy(img).float().to(torch.device('cuda')) takes 2 GB of GPU memory, and by the end of a single image transformation I’m OOM on 12 GB of VRAM. Is anyone able to help with this problem?
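For reference, here is the back-of-the-envelope calculation I’m basing that on (assuming the tensor ends up as float32, i.e. 4 bytes per value):

# rough size estimate for a single 1936x1216x3 float32 image
h, w, c = 1936, 1216, 3
bytes_per_float32 = 4
n_values = h * w * c                       # ~7.06 million values
size_mib = n_values * bytes_per_float32 / 1024**2
print(n_values, round(size_mib, 1))        # 7062528, ~26.9 MiB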

The first CUDA operation will create the CUDA context, which loads the driver, all native PyTorch kernels, and the CUDA math libraries (cuDNN, cuBLAS etc.), and will thus create this overhead. Depending on the linking strategy, the GPU architecture, the CUDA version etc., the size of the CUDA context might differ.
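To see how much of that is context overhead rather than tensor memory, you could compare what PyTorch's caching allocator reports against what nvidia-smi shows for the process; roughly, the difference is the context (plus any cached blocks). A minimal sketch, assuming a single process and device 0:

import torch

torch.cuda.init()  # force context creation without allocating any tensors
print(torch.cuda.memory_allocated() / 1024**2)  # memory used by tensors, in MB -> 0.0
print(torch.cuda.memory_reserved() / 1024**2)   # memory held by the caching allocator -> 0.0
# nvidia-smi will nevertheless show a sizeable amount for this process:
# that is the CUDA context, not tensor memory.

x = torch.randn(1024, 1024, device='cuda')
print(torch.cuda.memory_allocated() / 1024**2)  # ~4 MB for the float32 tensor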

Based on your description it seems as if this call is responsible for the OOM error, which sounds strange, as the CUDA context creation would have been done in any case.
Could you check if you are creating multiple contexts and are thus losing memory?

Thanks for the reply. That line is not responsible for the OOM, but the 2 GB seemed excessive; if it indeed loads the CUDA context, that makes sense. However, the memory usage rises to 11.7 GB during the next few lines (7.3 GB after resizing, and then OOM after the last line). The function is run only once with a single image, so the behavior seems weird.

I think you are hitting this issue because you are passing the image in a channels-last memory layout (H, W, C), while torchvision's resize expects channels-first ([..., H, W]) and interpolates the last two dimensions. While the call itself works, it treats your height dimension as channels and resizes the (width, channel) dimensions instead, which increases the memory usage massively:

import numpy as np
import torch
import torchvision

print(torch.cuda.memory_allocated()/1024**3)
# 0.0

img = np.random.randn(1936, 1216, 3)

img_size=1280
or_h, or_w = img.shape[:2]

scale_factor = min(img_size / or_h, img_size / or_w)
new_h = int(scale_factor * or_h)
new_w = int(scale_factor * or_w)

print((new_h, new_w))
# (1280, 803)

cuda_img = torch.from_numpy(img).float().to(torch.device('cuda'))
print(torch.cuda.memory_allocated()/1024**3)
# 0.026309967041015625

res_img = torchvision.transforms.functional.resize(cuda_img, size=[new_h, new_w])
print(torch.cuda.memory_allocated()/1024**3)
# 7.439258575439453

print(res_img.shape)
# torch.Size([1936, 1280, 803])

print(res_img.nelement() * res_img.element_size() / 1024**3)
# 7.4129486083984375

Permute the tensor to channels-first before calling resize and the memory usage should be reduced.
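Continuing from the snippet above, something like this (just a sketch; exact memory numbers will differ on your setup):

# permute to channels-first (C, H, W) before resizing, so only the
# spatial dimensions are interpolated
chw_img = cuda_img.permute(2, 0, 1)
res_img = torchvision.transforms.functional.resize(chw_img, size=[new_h, new_w])

print(res_img.shape)
# torch.Size([3, 1280, 803])

print(res_img.nelement() * res_img.element_size() / 1024**3)
# ~0.011, i.e. a few MB instead of ~7.4 GB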