Normalized convolution layer

I would like to create a custom convolution layer in which, after each back-propagation step and weight update, the weights of my kernel matrix are transformed into a specific format: the central value must be -1, and the remaining values must sum to 1.
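Written out for a single k×k kernel W with centre entry w_c (this notation is introduced here for illustration and is not from the original post), the constraint and one simple way to restore it after an update, by rescaling the off-centre weights, look like this:

```latex
% target constraint
w_c = -1, \qquad \sum_{(i,j) \neq c} w_{ij} = 1

% rescaling step applied after each update
S = \sum_{(i,j) \neq c} w_{ij}, \qquad
w_{ij} \leftarrow \frac{w_{ij}}{S} \;\; \text{for } (i,j) \neq c, \qquad
w_c \leftarrow -1

% additive alternative (defined even when S = 0)
w_{ij} \leftarrow w_{ij} + \frac{1 - S}{k^2 - 1} \;\; \text{for } (i,j) \neq c
```

The rescaling version assumes S is not zero; when S is close to zero it inflates the weights, which is why the additive correction is listed as an alternative.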


This might work.

import torch
import torch.nn.functional as F

# define images and conv2d output channels
images = torch.rand((5, 3, 16, 16))
out_channels = 3

# define custom kernel: centre is -1 and the four neighbours sum to 1
kernel = torch.tensor([[[[0, 0.25, 0], [0.25, -1, 0.25], [0, 0.25, 0]]]])

# repeat kernel for the input channels (dim 1) and output channels (dim 0)
kernel = torch.cat([kernel] * images.size(1), dim=1)
kernel = torch.cat([kernel] * out_channels, dim=0)

# pass through conv2d; padding=1 keeps the 16x16 spatial size
output = F.conv2d(images, kernel, padding=1)

Note, however, that the above does not involve any updates: it only shows how to apply a fixed kernel with conv2d. You would need to write a function that re-imposes your constraint on the kernel after each weight update.
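One way such an update function might look, as a sketch under several assumptions: the constraint is re-imposed right after `optimizer.step()`, every 2D slice (each output/input channel pair) is constrained separately, and the helper name `project_kernel_`, the toy data and the MSE loss are hypothetical, chosen only so the loop runs. The positive initialization is also an assumption; it keeps the off-centre sums away from zero before the first rescaling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def project_kernel_(weight: torch.Tensor, central_value: float = 1.0) -> None:
    """Hypothetical helper: rescale every 2D kernel slice in place so its
    off-centre entries sum to central_value and its centre is -central_value.
    Assumes an odd kernel size and a non-zero off-centre sum."""
    kh, kw = weight.shape[-2:]
    ch, cw = kh // 2, kw // 2
    with torch.no_grad():
        # off-centre sum for each (out, in) slice, shape (out, in)
        off_sum = weight.sum(dim=(-2, -1)) - weight[:, :, ch, cw]
        weight.div_(off_sum[..., None, None]).mul_(central_value)
        weight[:, :, ch, cw] = -central_value


torch.manual_seed(0)
conv = nn.Conv2d(3, 3, kernel_size=5, padding=2, bias=False)
nn.init.uniform_(conv.weight, 0.0, 1.0)  # assumption: positive init
project_kernel_(conv.weight)             # start from a valid kernel
opt = torch.optim.SGD(conv.parameters(), lr=0.01)

x = torch.rand(8, 3, 16, 16)       # toy inputs
target = torch.rand(8, 3, 16, 16)  # toy targets
for step in range(5):
    opt.zero_grad()
    loss = F.mse_loss(conv(x), target)
    loss.backward()
    opt.step()
    project_kernel_(conv.weight)   # re-impose the constraint after every update
```

Doing the projection after the optimizer step, rather than inside `forward()`, means the weights seen by every forward pass already satisfy the constraint and the parameter is never modified between a forward pass and its backward pass.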

What you describe covers only the kernel weight initialization, but these constraints must be enforced throughout training. Below is an implementation that I am not certain is correct; if anyone has comments, please let me know.

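import torch
import torch.nn as nn
import torch.nn.functional as F

class Constlayer(nn.Module):
    """
    Convolution layer whose 5x5 kernels are renormalized before every forward
    pass: the centre weight is set to -centralval and the remaining weights are
    rescaled to sum to centralval. Expects single-channel input.
    """
    def __init__(self, numch: int = 128, centralval: float = 1.0):
        super().__init__()
        self.numch = numch
        self.centralval = centralval
        # numch output kernels, 1 input channel, 5x5 spatial size
        self.const_weight = nn.Parameter(torch.randn(size=[numch, 1, 5, 5]), requires_grad=True)

    def normalize(self):
        # modify the weights in place without recording the ops in autograd
        with torch.no_grad():
            central_pixel = self.const_weight[:, 0, 2, 2]
            for i in range(self.numch):
                # sum of the 24 off-centre weights of kernel i
                kernel_sum = torch.sum(self.const_weight[i]) - central_pixel[i]
                self.const_weight[i] = (self.const_weight[i] / kernel_sum) * self.centralval
                self.const_weight[i, 0, 2, 2] = -self.centralval

    def forward(self, x):
        self.normalize()
        out = F.conv2d(x, self.const_weight)
        return out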
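If the per-kernel loop becomes slow for numch=128, the same normalization can be applied to all kernels at once with broadcasting. This is a sketch rather than a verified drop-in replacement: the class name `ConstLayerVectorized` is hypothetical, and it keeps the original's random initialization, single input channel and unpadded convolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConstLayerVectorized(nn.Module):
    """Hypothetical vectorized variant of Constlayer: each 5x5 kernel is
    rescaled so its 24 off-centre weights sum to centralval and its centre
    weight equals -centralval."""

    def __init__(self, numch: int = 128, centralval: float = 1.0):
        super().__init__()
        self.centralval = centralval
        self.const_weight = nn.Parameter(torch.randn(numch, 1, 5, 5))

    @torch.no_grad()
    def normalize(self) -> None:
        w = self.const_weight
        # off-centre sum for every kernel at once, shape (numch,)
        off_sum = w.sum(dim=(1, 2, 3)) - w[:, 0, 2, 2]
        w.div_(off_sum.view(-1, 1, 1, 1)).mul_(self.centralval)
        w[:, 0, 2, 2] = -self.centralval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.normalize()
        # no padding, as in the original, so a 5x5 kernel shrinks H and W by 4
        return F.conv2d(x, self.const_weight)


torch.manual_seed(0)
layer = ConstLayerVectorized(numch=4)
x = torch.rand(2, 1, 16, 16)
y = layer(x)
print(y.shape)  # torch.Size([2, 4, 12, 12])
```

Two caveats apply to both versions. The division is unstable when a kernel's off-centre sum is close to zero, which can happen with `randn` initialization. And because `normalize()` modifies the parameter in place inside `forward()`, calling the layer twice before a single `backward()` would likely raise autograd's "modified by an inplace operation" error; normalizing after `optimizer.step()` instead avoids that.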