Grouped Conv for Repeated Convolution Backprop

I have a 3D tensor where each channel is to be convolved with one single kernel. From a quick search, the fastest way to do this seems to be a grouped convolution with the number of groups set to the number of channels.

Here is a small reproducible example:

import torch
import torch.nn as nn
torch.manual_seed(0)


x = torch.rand(1, 3, 3, 3)
# Split the input into its three single-channel slices
first  = x[:, 0:1, ...]
second = x[:, 1:2, ...]
third  = x[:, 2:3, ...]

# One shared kernel, and a grouped conv whose weights are that kernel repeated
kernel = nn.Conv2d(1, 1, 3)
conv = nn.Conv2d(3, 3, 3, groups=3)
conv.weight.data = kernel.weight.data.repeat(3, 1, 1, 1)
conv.bias.data = kernel.bias.data.repeat(3)

>>> conv(x)
tensor([[[[-1.0085]],

         [[-1.0068]],

         [[-1.0451]]]], grad_fn=<MkldnnConvolutionBackward>)

>>> kernel(first), kernel(second), kernel(third)
(tensor([[[[-1.0085]]]], grad_fn=<ThnnConv2DBackward>),
 tensor([[[[-1.0068]]]], grad_fn=<ThnnConv2DBackward>),
 tensor([[[[-1.0451]]]], grad_fn=<ThnnConv2DBackward>))

As you can see, this works perfectly.

Now to my question: I need to do backprop through this (the kernel object). During backprop, each weight of conv gets its own update, but conv is really just kernel repeated 3 times, and at the end I only want a single updated kernel. How do I do this?

PS: I need to optimize for speed.
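
For what it's worth, here is a minimal sketch of one way to sidestep the problem entirely: build the grouped weight by repeating the single kernel inside the forward pass with the functional API, so autograd accumulates the three per-channel gradients directly into the one shared parameter. The names weight and bias below are illustrative, not from the snippet above.

import torch
import torch.nn.functional as F

# One shared kernel parameter (illustrative names)
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
bias = torch.zeros(1, requires_grad=True)

x = torch.rand(1, 3, 3, 3)

# Repeating inside the graph means backprop sums the three per-channel
# gradients into the single underlying parameter
out = F.conv2d(x, weight.repeat(3, 1, 1, 1), bias.repeat(3), groups=3)
out.sum().backward()

print(weight.grad.shape)  # torch.Size([1, 1, 3, 3])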

One possible answer is to average the three copies after the gradient update, like so:

kernel.weight.data = conv.weight.data.mean(0).unsqueeze(0)

Is this the best way to do it? And is it even right in the first place?
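
As a rough sanity check (assuming plain SGD, with all three copies starting identical): averaging the copies after the update is equivalent to stepping with the mean of the three per-copy gradients, whereas true weight sharing steps with their sum, so the averaged version effectively uses a third of the learning rate.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(1, 3, 3, 3)

kernel = nn.Conv2d(1, 1, 3)
conv = nn.Conv2d(3, 3, 3, groups=3)
conv.weight.data = kernel.weight.data.repeat(3, 1, 1, 1)
conv.bias.data = kernel.bias.data.repeat(3)

conv(x).sum().backward()

# True weight sharing would receive the SUM of the per-copy gradients...
print(conv.weight.grad.sum(0))
# ...while the averaging trick uses their MEAN: same direction,
# one third of the magnitude
print(conv.weight.grad.mean(0))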

For anyone stumbling across the same doubt, the answer can be found here