How to create a trainable customized convolutional layer

I am wondering how to customize a convolutional layer whose filter is determined by some function of a few parameters instead of by independent parameters. Here is a toy example: a and b are parameters, and a 2x2 filter is created as torch.tensor([[[[a+b, a-b], [a-b,a+b]]]]). As we can see, the filter is determined by a and b instead of by four independent parameters. In my experiment, the network consists of just one such filter. To test my code I set up a simple learning task with only one training instance: input torch.tensor([[[[1.0, 0],[0,0]]]]), true output torch.tensor([[2.0]]). My code and the experiment results are below.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # in my later experiment there might be multiple parameters,
        # so I made the generate_parameters() function
        a, b = self.generate_parameters()
        self.conv = nn.Conv2d(1, 1, 2, bias=False)

        kernel = self.generate_kernel(a, b)
        self.conv.weight = nn.Parameter(kernel)

    def forward(self, x):
        x = self.conv(x)
        return x

    def generate_parameters(self):
        a = nn.Parameter(torch.zeros(1, 1))
        b = nn.Parameter(torch.zeros(1, 1))
        return a, b

    def generate_kernel(self, a, b):
        return torch.tensor([[[[a+b, a-b], [a-b, a+b]]]])


model = MyModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

inputs = torch.tensor([[[[1.0, 0], [0, 0]]]])
# same target value as above, shaped like the conv output (1, 1, 1, 1)
# so that MSELoss does not have to broadcast
labels = torch.tensor([[[[2.0]]]])

for epoch in range(10):  # loop over the dataset multiple times
    for param in model.parameters():
        print("epoch:", epoch, param.data, param.size())

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
```

Experiment result:

epoch: 0 tensor([[[[0., 0.],
          [0., 0.]]]]) torch.Size([1, 1, 2, 2])
epoch: 1 tensor([[[[0.4000, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 2 tensor([[[[0.7200, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 3 tensor([[[[0.9760, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 4 tensor([[[[1.1808, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 5 tensor([[[[1.3446, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 6 tensor([[[[1.4757, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 7 tensor([[[[1.5806, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 8 tensor([[[[1.6645, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])
epoch: 9 tensor([[[[1.7316, 0.0000],
          [0.0000, 0.0000]]]]) torch.Size([1, 1, 2, 2])

As we can see, the trained filter is not what I expected: it should have kept the symmetric form [[a+b, a-b], [a-b, a+b]], but only its top-left entry changes. Ideally, the model should have only two parameters, a and b. Any advice on fixing my code or experiment? Thank you!
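A quick check of what the model above actually registers can help narrow this down. The sketch below re-creates only the relevant part of the question's MyModel (same names and shapes) and assumes a PyTorch version in which torch.tensor accepts one-element tensors inside a nested list, as the question's own run shows.

```python
import torch
import torch.nn as nn

# Two parameters, created exactly as in generate_parameters()
a = nn.Parameter(torch.zeros(1, 1))
b = nn.Parameter(torch.zeros(1, 1))

# torch.tensor(...) reads the current values of a+b and a-b and builds a new
# leaf tensor, so there is no autograd path back to a or b.
copied = torch.tensor([[[[a + b, a - b], [a - b, a + b]]]])
print(copied.requires_grad, copied.grad_fn)  # False None


class MyModel(nn.Module):
    # Trimmed copy of the question's model, keeping only what matters here.
    def __init__(self):
        super().__init__()
        a = nn.Parameter(torch.zeros(1, 1))  # local variables: never
        b = nn.Parameter(torch.zeros(1, 1))  # assigned to self
        self.conv = nn.Conv2d(1, 1, 2, bias=False)
        self.conv.weight = nn.Parameter(
            torch.tensor([[[[a + b, a - b], [a - b, a + b]]]]))


model = MyModel()
for name, p in model.named_parameters():
    print(name, p.numel())  # conv.weight 4
```

The optimizer therefore sees a single four-element weight and nothing named a or b.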

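You are wrapping the kernel in a new nn.Parameter, so its four entries become independent trainable values that each get their own gradient and can drift apart; nothing ties them to a and b or to each other. torch.tensor(...) only copies the current values of a+b and a-b, which cuts the autograd graph, and since a and b are never assigned to self, they are not registered and model.parameters() does not return them. With your input, only the top-left entry has a nonzero gradient, which is why it is the only value that moves.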
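Working the update out by hand shows what the two-parameter model would do instead. This assumes the question's settings (a = b = 0 at the start, MSE loss on the single instance, SGD with lr = 0.1) and writes s_t = a_t + b_t for the diagonal value after t steps:

```latex
\begin{aligned}
y &= (a+b)\cdot 1 + (a-b)\cdot 0 + (a-b)\cdot 0 + (a+b)\cdot 0 = a+b,\\
L &= (y-2)^2 = (a+b-2)^2,\\
\frac{\partial L}{\partial a} &= \frac{\partial L}{\partial b} = 2\,(a+b-2),\\
a_{t+1} &= a_t + 0.2\,(2 - s_t), \qquad b_{t+1} = b_t + 0.2\,(2 - s_t),\\
s_{t+1} &= s_t + 0.4\,(2 - s_t) \;\Rightarrow\; s_t = 2\,(1 - 0.6^{t}).
\end{aligned}
```

Because a and b always receive the same gradient, a - b stays 0 and the kernel stays [[s_t, 0], [0, s_t]], e.g. [[0.8, 0], [0, 0.8]] after one step. In the four-entry version, only the top-left weight w is updated, following w_{t+1} = w_t + 0.2(2 - w_t), which reproduces the 0.4, 0.72, 0.976, ... in the experiment log.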
To use a and b as the only parameters, register them on the module (for example as self.a and self.b), build the filter from them with differentiable operations such as torch.cat and torch.stack, and apply it through the functional API via F.conv2d(x, kernel).
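A minimal sketch of that approach follows. The attribute names self.a and self.b, their 1-element shape, and rebuilding the kernel in every forward pass are illustrative choices, not code from the original posts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning the parameters to self registers them with the module,
        # so model.parameters() yields exactly a and b.
        self.a = nn.Parameter(torch.zeros(1))
        self.b = nn.Parameter(torch.zeros(1))

    def generate_kernel(self):
        # torch.cat / torch.stack are differentiable, so gradients reach a and b.
        s = self.a + self.b
        d = self.a - self.b
        row0 = torch.cat([s, d])  # shape (2,)
        row1 = torch.cat([d, s])  # shape (2,)
        kernel = torch.stack([row0, row1])  # shape (2, 2)
        return kernel.reshape(1, 1, 2, 2)  # (out_channels, in_channels, kH, kW)

    def forward(self, x):
        # Rebuild the kernel on every call so it reflects the current a and b.
        return F.conv2d(x, self.generate_kernel())


model = MyModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

inputs = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])
labels = torch.tensor([[[[2.0]]]])  # same shape as the conv output (1, 1, 1, 1)

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()

print([name for name, _ in model.named_parameters()])  # ['a', 'b']
print(model.generate_kernel().detach())
```

The model now has exactly two parameters, and the learned kernel keeps the [[a+b, a-b], [a-b, a+b]] structure by construction. More parameters can be added the same way, as long as generate_kernel only uses differentiable tensor operations on them.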