How to implement a kernel-wise convolution

Hi, I’m currently designing a convolution method in which the kernel is multiplied element-wise by a different matrix at each sliding step. My current implementation looks like this:

self.unfold = nn.Unfold(kernel_size, dilation=dilation, padding=padding, stride=stride)  # nn.Unfold (capital U); its positional order is (kernel_size, dilation, padding, stride), so keywords are safer
self.w = nn.Parameter(torch.empty(in_channel, kernel_size, kernel_size, out_channel))
self.out_channel = out_channel

a = self.unfold(a)      # a has the same height and width as img, but only one channel -> (batch, kernel_size**2, L)
img = self.unfold(img)  # (batch, 3*kernel_size**2, L)
img_ = torch.mul(a.unsqueeze(1), img.view(batch, 3, kernel_size**2, -1))
output = torch.matmul(img_.view(batch, 3*kernel_size**2, -1).permute(0, 2, 1), self.w.view(-1, self.out_channel))
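For context, the snippet above can be wrapped into a runnable module. This is only a sketch of the unfold-based approach described in the post: the class name `KernelwiseConv`, the weight initialization, and the default `padding=1` are my assumptions, not the original code.

```python
import torch
import torch.nn as nn

class KernelwiseConv(nn.Module):
    # Sketch only: class/argument names and init are illustrative assumptions.
    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=1, dilation=1, stride=1):
        super().__init__()
        # nn.Unfold's positional order is (kernel_size, dilation, padding, stride),
        # so keyword arguments avoid silently swapping padding and dilation.
        self.unfold = nn.Unfold(kernel_size, dilation=dilation,
                                padding=padding, stride=stride)
        self.w = nn.Parameter(torch.empty(in_channel, kernel_size,
                                          kernel_size, out_channel))
        nn.init.kaiming_uniform_(self.w)  # assumed init, not from the post
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.k2 = kernel_size ** 2

    def forward(self, img, a):
        # img: (batch, in_channel, H, W); a: (batch, 1, H, W) per-pixel map
        batch = img.size(0)
        a_cols = self.unfold(a)      # (batch, k2, L)
        img_cols = self.unfold(img)  # (batch, in_channel*k2, L)
        img_ = a_cols.unsqueeze(1) * img_cols.view(batch, self.in_channel,
                                                   self.k2, -1)
        # (batch, L, in_channel*k2) @ (in_channel*k2, out) -> (batch, L, out)
        return torch.matmul(
            img_.view(batch, self.in_channel * self.k2, -1).permute(0, 2, 1),
            self.w.view(-1, self.out_channel))
```

With `kernel_size=3`, `padding=1`, `stride=1`, the number of columns `L` equals `H*W`, so the output is `(batch, H*W, out_channel)`.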

I followed the unfold approach to implement this convolution. But when I run the program, the output of unfold seems to occupy too much GPU memory, several GB. Is there any way to avoid this problem?
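One general way to cut the peak memory (this is a common rewrite, not necessarily the answer referred to below): loop over the k×k kernel offsets using shifted views of the padded inputs, so the `(batch, in_channel*k*k, H*W)` unfold output is never materialized; each pass only holds a `(batch, out_channel, H, W)` tensor. The function name and defaults here are my own.

```python
import torch
import torch.nn.functional as F

def kernelwise_conv_lowmem(img, a, w, kernel_size=3, padding=1):
    # Equivalent to the unfold version (output reshaped to (batch, out, H, W)),
    # but it iterates over kernel offsets instead of materializing unfold's
    # (batch, in_channel*k*k, H*W) output.
    # w: (in_channel, kernel_size, kernel_size, out_channel)
    batch, c, h, wd = img.shape
    img_p = F.pad(img, (padding,) * 4)  # pad (left, right, top, bottom)
    a_p = F.pad(a, (padding,) * 4)
    out = None
    for i in range(kernel_size):
        for j in range(kernel_size):
            img_s = img_p[:, :, i:i + h, j:j + wd]  # shifted view, no copy
            a_s = a_p[:, :, i:i + h, j:j + wd]      # (batch, 1, H, W)
            # per-offset 1x1 projection: sum over channels against w[:, i, j, :]
            contrib = torch.einsum('bchw,co->bohw', img_s * a_s, w[:, i, j, :])
            out = contrib if out is None else out + contrib
    return out  # (batch, out_channel, H, W)
```

Memory drops from O(in_channel·k²·H·W) for the unfolded columns to O(out_channel·H·W) per accumulation step, at the cost of k² small matmuls.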

OK, I found the answer I needed. Since I cannot delete this topic, I’ll just paste the answer here.