I am trying to learn the weights of a 3x3 conv2d layer with 3 input channels and 3 output channels. For this discussion, assume bias=0 in each case. However, the weights of the conv layer are learned indirectly. I have a two-layer multilayer perceptron with 9 nodes in the first layer and 9 in the second, i.e. nn.Linear(9, 9). The weights of the 2D conv layer are then exactly the weights learned by this MLP. The counts match: nn.Linear(9, 9) has 9x9 = 81 weights, and a 3x3 kernel with 3 input and 3 output channels has 3x3x3x3 = 81 weights. I understand that in this case I will have to use nn.functional.conv2d(input, weight), but it is not clear to me how exactly to extract the weights from the MLP and use them for the convolution. I can think of the following:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hyper = nn.Linear(9, 9)  # its 9x9 = 81 weights supply the 3x3x3x3 conv kernel

    def forward(self, x):
        # reshape the MLP's weight matrix to (out_channels, in_channels, kH, kW)
        w = self.hyper.weight.reshape(3, 3, 3, 3)
        y = F.conv2d(x, w, padding=1)  # bias = 0; add stride and other parameters as needed
        return y
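One way to check the idea above is a small sketch, assuming the conv kernel is simply `hyper.weight` reshaped to `(3, 3, 3, 3)`. It compares `F.conv2d` with the reshaped weight against an ordinary `nn.Conv2d` holding a copy of the same 81 numbers, then calls `backward()` to see whether gradients reach the MLP's weight matrix. The batch shape and variable names (`hyper`, `ref`) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

hyper = nn.Linear(9, 9)          # stand-in for the MLP from the question
x = torch.randn(2, 3, 8, 8)      # dummy batch: N=2, C=3, H=W=8

# Conv kernel built from the MLP's weight matrix: (out=3, in=3, kH=3, kW=3)
w = hyper.weight.reshape(3, 3, 3, 3)
y = F.conv2d(x, w, padding=1)

# Reference: a regular nn.Conv2d holding a copy of the same 81 numbers, bias disabled
ref = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    ref.weight.copy_(w)
same_output = torch.allclose(y, ref(x), atol=1e-6)

# Backprop: reshape is differentiable, so the gradient lands on hyper.weight
y.sum().backward()
grad_shape = tuple(hyper.weight.grad.shape)
print(same_output, grad_shape)   # True (9, 9)
```

Note that `hyper.bias` is never used in this setup, so its `.grad` stays `None`.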
Can someone provide a short dummy example in PyTorch of this training setup that allows backpropagation? Will the above code be problematic for PyTorch's backpropagation?
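A minimal dummy training loop is sketched below, assuming the same setup (the conv kernel is `self.hyper.weight` reshaped, conv bias = 0). The regression task is hypothetical: targets come from a fixed random kernel `true_w`. The batch size, the use of Adam, the learning rate and the step count are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hyper = nn.Linear(9, 9)  # its 9x9 weight matrix supplies the 81 conv weights

    def forward(self, x):
        w = self.hyper.weight.reshape(3, 3, 3, 3)  # (out_ch, in_ch, kH, kW)
        return F.conv2d(x, w, padding=1)           # bias = 0, as in the question

torch.manual_seed(0)
net = Net()

# Hypothetical dummy task: targets produced by a fixed random 3x3 conv kernel
true_w = torch.randn(3, 3, 3, 3)
x = torch.randn(16, 3, 8, 8)
target = F.conv2d(x, true_w, padding=1)

# hyper.bias gets no gradient (it is unused); Adam skips parameters whose grad is None
opt = torch.optim.Adam(net.parameters(), lr=0.05)
losses = []
for step in range(200):
    opt.zero_grad()
    loss = F.mse_loss(net(x), target)
    loss.backward()  # gradients flow through the reshape into net.hyper.weight
    opt.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.5f}")
```

Because the Linear layer is never actually called on an input here, this is a reparametrization that is equivalent to training the conv kernel directly. A hypernetwork in the usual sense would instead feed some input through the MLP and use its outputs as the kernel.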