PyTorch and Conv2d

I am using conv2d with large filters and my program “hangs”. Specifically, I have two (1-channel) 2048 x 2048 images. My convolution kernel is 501 x 501 with 48 output channels, i.e. I am expecting a 1 x 48 x 1548 x 1548 tensor as output. My code below “freezes” after “out = F.conv2d(input, weight)”. What am I doing wrong?

import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda:0'
# two 1-channel 2048 x 2048 inputs
inputs = [torch.zeros((1, 1, 2048, 2048)).to(device)] * 2
# 48 output channels, one 501 x 501 kernel per channel
weight = torch.ones((48, 1, 501, 501), dtype=torch.float32).to(device)
for i, input in enumerate(inputs):
    out = F.conv2d(input, weight)
    print(out.shape)
    print(torch.argmax(out))

EDIT: The program finished after many minutes! I am still unable to understand the reason for this long processing time.

The code takes a few seconds on my system, but the actual performance depends on the hardware, library versions, etc. The conv layer most likely does not have an optimized algorithm for these shapes, so a matmul-based approach would be used.
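
For illustration, here is a small sketch of the general im2col idea (a convolution expressed as a single matrix multiplication via F.unfold). This is only meant to show why such a fallback is memory- and compute-heavy, not the exact kernel the backend would pick; the toy sizes are chosen so the unfolded buffer stays small, since with 2048 x 2048 inputs and a 501 x 501 kernel it would be enormous.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 16, 16)
w = torch.randn(4, 1, 5, 5)

cols = F.unfold(x, kernel_size=5)    # (1, 25, 144): one flattened 5x5 patch per column
out_mm = w.view(4, -1) @ cols        # (1, 4, 144): convolution as one matmul
out_mm = out_mm.view(1, 4, 12, 12)

print(torch.allclose(out_mm, F.conv2d(x, w), atol=1e-5))  # True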

Thank you for the response @ptrblck

I understand that I may be trying shapes that are not optimized. My confusion arose from the following:
For i=0, out.shape is printed quickly, but the program seems to slow down after that.
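
For reference, a minimal sketch of how the conv itself could be timed with an explicit synchronization (assuming a CUDA device is available), since CUDA kernel launches are asynchronous and the wait might only surface at the later argmax/print:

import time
import torch
import torch.nn.functional as F

device = 'cuda:0'
x = torch.zeros((1, 1, 2048, 2048), device=device)
w = torch.ones((48, 1, 501, 501), device=device)

torch.cuda.synchronize()             # make sure setup work has finished
start = time.perf_counter()
out = F.conv2d(x, w)
torch.cuda.synchronize()             # wait for the conv kernel itself
print(out.shape, time.perf_counter() - start, "seconds")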

Is there a max kernel size for which optimized algorithms are available?