Custom 2D kernel: keep the center value at zero and exclude it from training updates

I have the following model, in which I perform a dilated conv2d operation with a 3x3 kernel (self.cnn4_dialation in the code). I want the center value of this kernel to always be zero, i.e. it should not be updated during training, while the other values are updated as usual.

Would anyone please tell me how to do this?

custom kernel format:

a b c
d 0 e
f g h

where a, b, c, d, e, f, g, h are being updated while training.

The defined model:

import torch
import torch.nn as nn

class base_Model(nn.Module):
    def __init__(self):
        super(base_Model,self).__init__()
        self.cnn1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.cnn2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.cnn3 = nn.Conv2d(in_channels=32, out_channels=100, kernel_size=3, stride=1, padding=1, dilation=1)
        self.relu3 = nn.ReLU()
        self.cnn4_dialation = nn.Conv2d(in_channels=100, out_channels=100, kernel_size=3, stride=1, padding=5, dilation=5)
        self.relu4 = nn.ReLU()
        self.cnn5_1x1 = nn.Conv2d(in_channels=100, out_channels=100, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        top_layers = self.relu3(self.cnn3(self.maxpool1(self.relu2(self.cnn2(self.relu1(self.cnn1(x)))))))
        out_dialated = self.cnn5_1x1(self.relu4(self.cnn4_dialation(top_layers)))
        return out_dialated

Hi,

You can very easily set the gradient for this entry to 0 at every backward pass by simply adding a hook on the weights like:

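def hook_fn(grad):
    res = grad.clone()  # You are not allowed to change the input in place
    res[:, :, 1, 1] = 0  # Zero the center of every 3x3 kernel (weight shape is [out, in, 3, 3])
    return res  # The returned tensor replaces the gradient
model.cnn1.weight.register_hook(hook_fn)  # For the first layer of your model above.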

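The limitation of this approach is that the optimizer can still change the entry even when its gradient is exactly 0: Adam and SGD with momentum keep applying previously accumulated state, and weight decay acts on any nonzero weight. The hook also only freezes the entry, so it has to be set to 0 once beforehand.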
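To make the limitation concrete, here is a minimal sketch. It assumes a small stand-in layer (4 channels instead of 100) and SGD with momentum, neither of which comes from the question; the names conv, opt and zero_center_grad are made up for the illustration. One ordinary step fills the momentum buffer, then the center is zeroed and the gradient hook is attached.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the dilated layer, with fewer channels to keep the demo fast (assumption).
conv = nn.Conv2d(4, 4, kernel_size=3, padding=5, dilation=5)
opt = torch.optim.SGD(conv.parameters(), lr=0.1, momentum=0.9)
x = torch.randn(2, 4, 16, 16)

# Step 1: an ordinary update, which fills the momentum buffer for every weight.
conv(x).pow(2).mean().backward()
opt.step()
opt.zero_grad()

# Zero the center taps and mask their gradient from here on.
with torch.no_grad():
    conv.weight[:, :, 1, 1] = 0

def zero_center_grad(grad):
    res = grad.clone()
    res[:, :, 1, 1] = 0
    return res

conv.weight.register_hook(zero_center_grad)

# Step 2: the center gradient is now exactly zero...
conv(x).pow(2).mean().backward()
center_grad_max = conv.weight.grad[:, :, 1, 1].abs().max().item()
opt.step()

# ...but the momentum left over from step 1 still moves the center taps.
center_weight_max = conv.weight[:, :, 1, 1].abs().max().item()
print(center_grad_max, center_weight_max)
```

If the hook is registered and the center is zeroed before the very first optimizer step, and no weight decay is used, the accumulated state for that entry stays at zero and this drift should not appear.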

Another approach is to manually set it back to 0 before each forward pass:

def pre_fw_hook(mod, *args):
    with torch.no_grad():  # Don't track this in gradient computations
        mod.weight[:, :, 1, 1] = 0  # Zero the center of every 3x3 kernel
model.cnn1.register_forward_pre_hook(pre_fw_hook)

But then the weights you read right after an optimizer step might not contain the 0, and you need to be careful to register this hook again after loading the model.
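Both caveats can be seen in a short sketch. It assumes a reduced 4-channel layer and Adam, and make_layer is a hypothetical helper introduced only so that the hook is attached every time the layer is built, including before loading saved weights.

```python
import torch
import torch.nn as nn

def zero_center_pre_hook(mod, args):
    # Runs before every forward call of the module it is registered on.
    with torch.no_grad():
        mod.weight[:, :, 1, 1] = 0

def make_layer():
    # Hypothetical helper: build a small dilated layer and attach the hook.
    layer = nn.Conv2d(4, 4, kernel_size=3, padding=5, dilation=5)
    layer.register_forward_pre_hook(zero_center_pre_hook)
    return layer

torch.manual_seed(0)
layer = make_layer()
opt = torch.optim.Adam(layer.parameters(), lr=1e-2)
x = torch.randn(2, 4, 16, 16)

layer(x).pow(2).mean().backward()  # the hook zeroed the center before this forward
opt.step()                         # the update can make the center nonzero again
center_after_step = layer.weight[:, :, 1, 1].abs().max().item()

# Hooks are not stored in state_dict, so rebuild the layer with the hook
# attached before loading the saved weights.
reloaded = make_layer()
reloaded.load_state_dict(layer.state_dict())
reloaded(x)                        # the pre-hook runs here
center_after_forward = reloaded.weight[:, :, 1, 1].abs().max().item()
print(center_after_step, center_after_forward)
```

The forward pass itself always sees a zero center; only the stored weights between an optimizer step and the next forward call can contain a nonzero value.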


Thank you for your reply. I made a mistake in the question: I did not explicitly mention that self.cnn4_dialation is the layer that should have the custom kernel. The weight of self.cnn4_dialation has shape [100, 100, 3, 3].

I would like to go with the forward hook. However, can I instead set the value to 0 manually in the def forward(self, x): definition, as below:

    def forward(self, x):
        with torch.no_grad():
            self.cnn4_dialation.weight[:, :, 1, 1] = 0
        top_layers = self.relu3(self.cnn3(self.maxpool1(self.relu2(self.cnn2(self.relu1(self.cnn1(x)))))))
        out_dialated = self.cnn5_1x1(self.relu4(self.cnn4_dialation(top_layers)))
        return out_dialated

Will that work the same as the forward hook?

Hi,

Yes, that works fine!
For some reason I had in mind that you were using a Sequential and did not have access to the forward function. But yes, changing this in the forward function is even better!
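For completeness, a minimal end-to-end sketch of the forward-function variant. TinyDilated is a hypothetical, reduced version of the model in the question (fewer layers and channels), and the zero_center method, the Adam optimizer and the squared-output loss are assumptions made only to keep the example short.

```python
import torch
import torch.nn as nn

class TinyDilated(nn.Module):
    # Hypothetical, reduced version of the model in the question.
    def __init__(self):
        super().__init__()
        self.cnn3 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.cnn4_dialation = nn.Conv2d(8, 8, kernel_size=3, padding=5, dilation=5)

    def zero_center(self):
        # Reset the center tap of every 3x3 kernel without tracking gradients.
        with torch.no_grad():
            self.cnn4_dialation.weight[:, :, 1, 1] = 0

    def forward(self, x):
        self.zero_center()
        return self.cnn4_dialation(self.relu3(self.cnn3(x)))

torch.manual_seed(0)
model = TinyDilated()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(4, 3, 16, 16)

for _ in range(3):
    opt.zero_grad()
    model(x).pow(2).mean().backward()
    opt.step()

# The last optimizer step may have moved the center, so zero it once more
# before saving or inspecting the weights.
model.zero_center()
w = model.cnn4_dialation.weight
center_max = w[:, :, 1, 1].abs().max().item()
corner_max = w[:, :, 0, 0].abs().max().item()
print(center_max, corner_max)
```

Calling zero_center() once after training is an extra step beyond what the thread describes; it only ensures that a saved checkpoint contains the zero as well.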