I am getting
RuntimeError: expected stride to be a single integer value or a list of 3 values to match the convolution dimensions, but got stride=[1, 1]
. Other posts say that this error occurs when there is batch dimension is missing, but it seems I get proper dimensions, the output of print is x.shape in torch.Size([5, 512, 14, 14])
. So what might be a problem here?
class SimpleConvo(nn.Module):
def __init__(self, channels, kernel_size, dim=2):
super(SimpleConvo, self).__init__()
kernel = torch.ones(size=(channels, kernel_size, kernel_size))
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
self.conv = F.conv2d if dim == 2 else F.conv3d
def forward(self, x):
print("x.shape in ", x.shape)
x = self.conv(x, weight=self.weight.cuda(), groups=self.groups)
return x
EDIT:
Usage:
convos = SimpleConvo(channels=512, kernel_size=3)
mat = convos(mat)