I am getting this error after implementing a hypernetwork to learn the weights for a CNN. For example, this is how I define the shape of one of the weights:
'conv0_0_weight': torch.Size([c//2, in_planes, 4, 4, 4])
This is how it is used:
x = functional_conv3d(x, weights['conv0_0_weight'], weights['conv0_0_bias'], self.c//2, stride=2, padding=1)
and this is the functional conv3d:
def functional_conv3d(x, weight, bias, out_planes, stride=1, padding=1, dilation=1):
    x = F.conv3d(x, weight, bias, stride=stride, padding=padding, dilation=dilation)
    x = nn.PReLU(num_parameters=out_planes)(x)
    return x
I printed the shapes of the input tensor and weights:
# input x: torch.Size([2, 3, 128, 128, 128])
# conv0_0_weight : torch.Size([2, 128, 3, 4, 4, 4])
The leading 2 in both shapes is the batch size. I checked the documentation, and the c//2, in_planes, 4, 4, 4 in my weight shape follow the expected order of out_channels, in_channels, kT, kH, kW. Is there something I am missing?
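For reference, here is a minimal sketch that reproduces the mismatch with scaled-down sizes (the values of batch, in_planes, out_planes and size are stand-ins, not the ones from my model). It shows that F.conv3d rejects a weight with a leading batch dimension. The second half is only an assumption about one possible way to apply per-sample weights: folding the batch into the channel dimension and using groups=batch, checked against a per-sample loop.

```python
import torch
import torch.nn.functional as F

# Scaled-down stand-ins for the shapes in the question (assumed values).
batch, in_planes, out_planes, size = 2, 3, 4, 8

x = torch.randn(batch, in_planes, size, size, size)
# Hypernetwork-style output: one weight set per sample, so a 6-D tensor.
weight = torch.randn(batch, out_planes, in_planes, 4, 4, 4)
bias = torch.randn(batch, out_planes)

# F.conv3d expects a 5-D weight, so the 6-D tensor is rejected.
try:
    F.conv3d(x, weight, stride=2, padding=1)
    failed = False
except RuntimeError as err:
    failed = True
    print("conv3d rejected the batched weight:", err)

# Possible workaround (assumption): fold the batch into the channels
# and let groups=batch keep each sample paired with its own filters.
x_folded = x.reshape(1, batch * in_planes, size, size, size)
w_folded = weight.reshape(batch * out_planes, in_planes, 4, 4, 4)
b_folded = bias.reshape(batch * out_planes)
out = F.conv3d(x_folded, w_folded, b_folded, stride=2, padding=1, groups=batch)
out = out.reshape(batch, out_planes, *out.shape[2:])

# Reference: convolve each sample with its own weights in a loop.
ref = torch.cat([
    F.conv3d(x[i:i + 1], weight[i], bias[i], stride=2, padding=1)
    for i in range(batch)
])
print(out.shape, torch.allclose(out, ref, atol=1e-4))
```

The loop is the simplest form to read. The grouped version does the same computation in a single call, at the cost of the extra reshapes.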