I have a 3D tensor, and I want to convolve each channel with the same single kernel. From a quick search, the **fastest** way to do this seems to be a grouped convolution with the number of groups equal to the number of channels.

Here is a small reproducible example:

```
import torch
import torch.nn as nn
torch.manual_seed(0)
x = torch.rand(1, 3, 3, 3)
first = x[:, 0:1, ...]
second = x[:, 1:2, ...]
third = x[:, 2:3, ...]
kernel = nn.Conv2d(1, 1, 3)
conv = nn.Conv2d(3, 3, 3, groups=3)
conv.weight.data = kernel.weight.data.repeat(3, 1, 1, 1)
conv.bias.data = kernel.bias.data.repeat(3)
>>> conv(x)
tensor([[[[-1.0085]],

         [[-1.0068]],

         [[-1.0451]]]], grad_fn=<MkldnnConvolutionBackward>)
>>> kernel(first), kernel(second), kernel(third)
(tensor([[[[-1.0085]]]], grad_fn=<ThnnConv2DBackward>),
 tensor([[[[-1.0068]]]], grad_fn=<ThnnConv2DBackward>),
 tensor([[[[-1.0451]]]], grad_fn=<ThnnConv2DBackward>))
```

As you can see, this works perfectly.

Now to my question: I need to do backprop on this and end up with an updated `kernel`. But when I backprop through `conv`, each of its three weight copies gets its own update (see the sketch below), even though `conv` is really just `kernel` repeated 3 times. At the end I only need a single updated `kernel`. How do I do this?
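
To illustrate, after a backward pass each of the three copies inside `conv` ends up with its own gradient (a quick sketch continuing the session above):

```
conv(x).sum().backward()
# conv.weight has shape (3, 1, 3, 3): three independent copies of the kernel,
# and conv.weight.grad holds one independent gradient per group
print(conv.weight.grad.shape)  # torch.Size([3, 1, 3, 3])
```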

PS: I need to optimize for speed.
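
For clarity, the gradient behaviour I'm after is something like the following sketch, where the repeated weights stay tied to `kernel` so backprop accumulates a single gradient on it, though I don't know whether this keeps the speed of the grouped `conv` above (hence the question):

```
import torch.nn.functional as F

# Sketch: reuse kernel's parameters for all three groups; repeat() is
# differentiable, so the gradients from the three groups accumulate
# back onto the single kernel.weight / kernel.bias
out = F.conv2d(x,
               kernel.weight.repeat(3, 1, 1, 1),  # (3, 1, 3, 3), tied to kernel.weight
               kernel.bias.repeat(3),             # (3,), tied to kernel.bias
               groups=3)
out.sum().backward()
print(kernel.weight.grad.shape)  # torch.Size([1, 1, 3, 3]) -> one accumulated gradient
```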