Applying a custom mask to a kernel in a CNN

Is this the correct way to apply a custom manipulation to the weights of a convolution layer?

import torch
import torch.nn as nn


class MaskedConv3d(nn.Module):
    def __init__(self, channels, filter_mask):
        super().__init__()
        self.kernel_size = tuple(filter_mask.shape)
        self.filter_mask = nn.Parameter(filter_mask)  # mask tensor
        self.conv = nn.Conv3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        self._mask_conv_filter()
        return self.conv(x)

    def _mask_conv_filter(self):
        with torch.no_grad():
            self.conv.weight.data = self.conv.weight.data * self.filter_mask

Specifically, in the last line of code I'm using the .data attribute rather than the tensors themselves, since otherwise I get the following error:

TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Original code:

        with torch.no_grad():
            self.conv.weight = self.conv.weight * self.filter_mask

Thanks
Barak

No, you shouldn’t use the .data attribute, as it might yield silent errors and could break your code in various ways.

The error message points to a mismatch between a tensor and the expected nn.Parameter.
Try to wrap the new weight into a parameter via:

with torch.no_grad():
    self.conv.weight = nn.Parameter(self.conv.weight * self.filter_mask)

Also, since self.filter_mask is used in a no_grad() block only, I assume it won’t be trained and can thus be registered as a buffer via:

self.register_buffer('filter_mask', filter_mask)
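
Putting the two suggestions together, the module from the question could look roughly like this (a minimal sketch only, keeping the names from the original post):

import torch
import torch.nn as nn


class MaskedConv3d(nn.Module):
    def __init__(self, channels, filter_mask):
        super().__init__()
        self.kernel_size = tuple(filter_mask.shape)
        # the mask is not trained, so register it as a buffer
        self.register_buffer('filter_mask', filter_mask)
        self.conv = nn.Conv3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        self._mask_conv_filter()
        return self.conv(x)

    def _mask_conv_filter(self):
        with torch.no_grad():
            # re-wrap the masked weight in an nn.Parameter instead of
            # assigning a plain tensor or touching .data
            self.conv.weight = nn.Parameter(self.conv.weight * self.filter_mask)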

Thanks @ptrblck for your prompt answer.

This actually broke the backward pass for me: the weight matrix stays the same after backward is called. Maybe something changed between PyTorch versions?

The gradient should not be affected by the masking, since it is applied inside the no_grad() context manager. @BBLN

I meant that the gradient for the weights themselves stopped working at all, maybe because the parameter object is replaced on every call to the module's forward pass.
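
For reference, a common way to avoid replacing the parameter object on every forward pass is to leave self.conv.weight untouched and apply the mask functionally inside forward. A minimal sketch (the class name MaskedConv3dFunctional is just for illustration; the mask is assumed to be fixed and non-trainable):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv3dFunctional(nn.Module):
    def __init__(self, channels, filter_mask):
        super().__init__()
        # fixed, non-trainable mask
        self.register_buffer('filter_mask', filter_mask)
        self.conv = nn.Conv3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=tuple(filter_mask.shape),
        )

    def forward(self, x):
        # the multiplication is part of the autograd graph, so masked
        # positions receive zero gradient, while self.conv.weight itself
        # is never replaced and optimizer references stay valid
        masked_weight = self.conv.weight * self.filter_mask
        return F.conv3d(x, masked_weight, bias=self.conv.bias)

Since the same nn.Parameter is reused, an optimizer created from model.parameters() keeps updating the real weight, and the mask is simply re-applied on every call.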