Hello PyTorchers!
I am trying to implement a 3D convolutional layer whose kernels have some sampling locations completely masked out. Specifically, I want to pass a binary mask such that locations set to zero do NOT contribute to the learning process.
In the example below, I use a cross-shaped mask and multiply it by the convolutional kernel weights to eliminate responses from the zeroed positions (where mask == 0).
Now my question comes in two parts:

1. Am I achieving this correctly? Note that my aim is to keep using this constant mask throughout all forward/backward passes, with no updates to the mask itself (so that, at any point during training, only the locations of interest are being learned).
2. Is there a way to also customize the bias term, so that it is only added at the locations of interest?
```python
import torch
import torch.nn as nn


class MaskedConv3d(nn.Module):
    def __init__(self, n_in, n_out, filter_mask, pad=1):
        super().__init__()
        self.kernel_size = tuple(filter_mask.shape)
        # Buffer: saved in the state_dict and moved with .to(device),
        # but never treated as a learnable parameter.
        self.register_buffer('filter_mask', filter_mask)
        self.conv = nn.Conv3d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=self.kernel_size,
            stride=1,
            bias=False,
            padding=pad,
        )

    def forward(self, x):
        self._mask_conv_filter()
        return self.conv(x)

    def _mask_conv_filter(self):
        # Zero the masked kernel positions before each forward pass.
        with torch.no_grad():
            self.conv.weight = nn.Parameter(self.conv.weight * self.filter_mask)


# Define a cross-shaped 3x3x3 mask:
mask = torch.tensor([[[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]],

                     [[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]],

                     [[0., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 0.]]])

x = torch.randn(1, 1, 6, 6, 6)  # (batch, channels, D, H, W)
maskedconv3d = MaskedConv3d(1, 8, mask)
out = maskedconv3d(x)  # shape: (1, 8, 6, 6, 6)
```
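For comparison, here is an alternative I have been considering: applying the mask inside `forward` via the functional API, so the multiplication stays in the autograd graph and the gradient at masked positions is zeroed automatically, without mutating the parameter in place. This is only a rough sketch; the class name and the weight initialization are my own choices, not anything prescribed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv3dFunctional(nn.Module):
    """Sketch: mask the kernel inside forward with F.conv3d.

    Because (weight * mask) is computed inside the graph, autograd gives
    d(loss)/d(weight) == 0 wherever mask == 0, so masked positions are
    never updated and the parameter object itself is never replaced.
    """

    def __init__(self, n_in, n_out, filter_mask, pad=1):
        super().__init__()
        self.register_buffer('filter_mask', filter_mask)
        kernel_size = tuple(filter_mask.shape)
        # Ad-hoc initialization for the sketch; any init scheme would do.
        self.weight = nn.Parameter(0.1 * torch.randn(n_out, n_in, *kernel_size))
        self.pad = pad

    def forward(self, x):
        # The mask broadcasts over the (out_channels, in_channels) dims.
        return F.conv3d(x, self.weight * self.filter_mask,
                        bias=None, stride=1, padding=self.pad)
```

After `loss.backward()`, `self.weight.grad` is exactly zero at every position where `filter_mask` is zero, which seems to match the "only locations of interest are learned" goal without re-wrapping the weight in a new `nn.Parameter` on every forward pass.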