Setting custom kernel for CNN in pytorch

You would need to set `requires_grad=True` for the weights. The functional approach would then also work, since `nn.Conv2d` internally just calls the functional API (see here).
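Here is a minimal sketch of the functional approach. It assumes an illustrative input of shape `(1, nb_channels, 24, 24)` and a single 3x3 output kernel. `nb_channels`, the input `x`, and the random kernel are placeholders for your own data:

```python
import torch
import torch.nn.functional as F

nb_channels = 3  # illustrative number of input channels
x = torch.randn(1, nb_channels, 24, 24)  # illustrative input: (N, C, H, W)

# Placeholder custom kernel of shape (out_channels, in_channels, kH, kW).
# requires_grad=True makes autograd track it, so gradients reach the weights.
weights = torch.randn(1, nb_channels, 3, 3, requires_grad=True)

output = F.conv2d(x, weights)  # stride 1, no padding -> shape (1, 1, 22, 22)
output.mean().backward()
print(weights.grad.shape)  # same shape as weights
```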

However, if you prefer to use the module, you could try the following code:

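```python
import torch
import torch.nn as nn

nb_channels = 3  # example value
x = torch.randn(1, nb_channels, 24, 24)  # example input: (N, C, H, W)
weights = torch.randn(1, nb_channels, 3, 3)  # example kernel, replace with your own: (out, in, kH, kW)

conv = nn.Conv2d(nb_channels, 1, 3, bias=False)
with torch.no_grad():
    conv.weight = nn.Parameter(weights)  # Parameter requires grad by default

output = conv(x)
output.mean().backward()
print(conv.weight.grad)
```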
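The sketch below checks that `nn.Conv2d` just calls the functional API. It runs the same kernel through both paths and compares outputs and gradients. The shapes and the random kernel are illustrative stand-ins for your own weights:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nb_channels = 3  # illustrative
x = torch.randn(1, nb_channels, 24, 24)
weights = torch.randn(1, nb_channels, 3, 3)  # hypothetical custom kernel

# Module path: replace the layer's weight with the custom kernel.
conv = nn.Conv2d(nb_channels, 1, 3, bias=False)
conv.weight = nn.Parameter(weights.clone())

# Functional path: the same kernel as a leaf tensor that requires grad.
w = weights.clone().requires_grad_(True)

out_module = conv(x)
out_functional = F.conv2d(x, w)
print(torch.allclose(out_module, out_functional))

# Both paths should also produce matching gradients w.r.t. the kernel.
out_module.mean().backward()
out_functional.mean().backward()
print(torch.allclose(conv.weight.grad, w.grad))
```

Both comparisons are expected to hold because, with zero padding mode, the module's forward pass forwards its weight to `F.conv2d` with the same stride, padding, dilation, and groups.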