I am using conv2d with large filters and my program “hangs”. Specifically, I have two single-channel 2048 x 2048 images. My convolution kernel is 501 x 501 with 48 output channels, i.e. I am expecting a 1 x 48 x 1548 x 1548 tensor as output for each image. My code below “freezes” after “out = F.conv2d(input, weight)”. What am I doing wrong?
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda:0'
inputs = [torch.zeros((1, 1, 2048, 2048)).to(device)]*2
weight = torch.ones((48, 1, 501, 501), dtype=torch.float32).to(device)
for (i, input) in enumerate(inputs):
    out = F.conv2d(input, weight)
    print(out.shape)
    print(torch.argmax(out))
EDIT: The program finished after many minutes! I still do not understand why the processing takes this long.
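
For scale, here is a rough back-of-the-envelope count of the work involved, as a minimal sketch. It assumes a direct (sliding-window) convolution, which may not be the algorithm the GPU backend actually selects. The helper names conv_output_size and direct_conv_macs are made up for illustration. The output-size formula is the one given in the PyTorch Conv2d documentation.

```python
def conv_output_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis (formula from the Conv2d docs)."""
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def direct_conv_macs(out_channels, in_channels, out_h, out_w, k_h, k_w):
    """Multiply-accumulates for one image, ASSUMING a direct convolution.

    Each output value needs in_channels * k_h * k_w multiply-accumulates.
    """
    return out_channels * in_channels * out_h * out_w * k_h * k_w


# Sizes taken from the question: 2048 x 2048 input, 501 x 501 kernel, 48 filters.
out_size = conv_output_size(2048, 501)
macs = direct_conv_macs(48, 1, out_size, out_size, 501, 501)

print(out_size)  # 1548, matching the expected 1 x 48 x 1548 x 1548 output
print(f"{macs:.3e} multiply-accumulates per image")
```

Under that assumption, each image needs roughly 2.9 x 10^13 multiply-accumulates, because every one of the 48 x 1548 x 1548 output values sums over 501 x 501 = 251,001 kernel taps. The cost grows with the kernel area, so a 501 x 501 filter is far more expensive than the small filters conv2d is typically used with.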