I am trying to implement a 3D convolutional layer in which some sampling locations of the kernel are completely masked out. Specifically, I want to pass a binary mask such that locations set to zero do NOT contribute to the learning process.
In the example below, I use a cross-shaped mask and multiply it by the convolutional kernel weights, so that responses from the zeroed positions (where mask == 0) are eliminated.
Now my question comes in two parts:
Am I achieving this correctly? Note that my aim is to keep using this constant mask during all forward/backward passes, with no updates to the mask itself (so that at any point during training, only the locations of interest are being learned).
Is there a way to also customize the bias term, so that it is only added at the locations of interest?
```python
import torch
import torch.nn as nn


class MaskedConv3d(nn.Module):
    def __init__(self, n_in, n_out, filter_mask, pad=1):
        super().__init__()
        self.kernel_size = tuple(filter_mask.shape)
        # buffer: saved with the module's state, but not a learnable parameter
        self.register_buffer('filter_mask', filter_mask)
        self.conv = nn.Conv3d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=self.kernel_size,
            stride=1,
            bias=False,
            padding=pad,
        )

    def forward(self, x):
        self._mask_conv_filter()
        return self.conv(x)

    def _mask_conv_filter(self):
        # zero out the masked kernel weights before every forward pass
        with torch.no_grad():
            self.conv.weight = nn.Parameter(self.conv.weight * self.filter_mask)


# define a cross-shaped 3x3x3 mask:
mask = torch.tensor([[[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]],

                     [[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]],

                     [[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]]])

x = torch.randn(1, 1, 6, 6, 6)
maskedconv3d = MaskedConv3d(1, 8, mask)
out = maskedconv3d(x)
```
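For comparison, here is a variant I have been considering (a sketch, not something I am sure is equivalent): instead of reassigning the Parameter under `no_grad` on every forward pass, apply the mask functionally inside `forward`, so that autograd itself zeroes the gradient at the masked positions via the chain rule. The class name `MaskedConv3dFunctional` and the init choice are mine, not from any library:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv3dFunctional(nn.Module):
    """Sketch: mask the weight inside forward instead of mutating the Parameter."""

    def __init__(self, n_in, n_out, filter_mask, pad=1):
        super().__init__()
        self.register_buffer('filter_mask', filter_mask)
        # same weight shape as nn.Conv3d: (out_ch, in_ch, kD, kH, kW)
        self.weight = nn.Parameter(torch.empty(n_out, n_in, *filter_mask.shape))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # nn.Conv3d's default init
        self.pad = pad

    def forward(self, x):
        # the multiply is part of the autograd graph, so d(loss)/d(weight)
        # is exactly zero wherever filter_mask is zero -- masked taps never update.
        # Note bias is omitted: nn.Conv3d's bias is per-output-channel and
        # spatially constant, so masking it per kernel position would need
        # extra machinery anyway.
        return F.conv3d(x, self.weight * self.filter_mask, padding=self.pad)


# same cross-shaped 3x3x3 mask as above, built by repeating one 2D cross
cross = torch.tensor([[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]])
mask = cross.expand(3, 3, 3).clone()

layer = MaskedConv3dFunctional(1, 8, mask)
out = layer(torch.randn(1, 1, 6, 6, 6))
out.sum().backward()
# the gradient at the masked kernel positions is exactly zero
print(layer.weight.grad[:, :, mask == 0].abs().max())
```

With this variant the stored `weight` tensor is never overwritten, so a previously constructed optimizer keeps pointing at the live Parameter; the effective kernel seen by the convolution is still `weight * filter_mask` on every pass.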