Masked Convolutions

I want to implement a convolution that skips spatial positions of the feature maps/images depending on a binary mask. In other words, the sliding window should skip a position of a layer's feature maps when the corresponding window of the mask is all zero. One way to get the same output is to multiply the convolution output by the binary mask, but that still computes every position. I would rather skip the convolution entirely at input pixels whose corresponding mask pixel is 0. A rough pseudo-code of the operation is below.

    for rw in range(n_rows):
        for cl in range(n_cols):
            # skip positions where the mask is 0
            if mask[rw, cl] != 0:
                # 3x3 window centred on (rw, cl), covering all channels at once
                out[rw, cl] = conv(img[rw-1:rw+2, cl-1:cl+2, :], filter)

I would appreciate it if you could point me to a solution in PyTorch.
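One way to approximate the pseudo-code in PyTorch is to unfold the input into patches, keep only the patches at positions where the mask is non-zero, and apply the filters to those with a single matrix multiply. This is a minimal sketch under stated assumptions: stride 1, odd kernel sizes with "same" padding, and a mask of shape (B, H, W). The function name sparse_masked_conv2d is hypothetical. Note that F.unfold still materializes every patch, so this saves the multiply-adds at masked-out positions but not the memory traffic of building the patches.

```python
import torch
import torch.nn.functional as F


def sparse_masked_conv2d(x, weight, bias, mask):
    """Hypothetical helper: stride-1 'same' conv evaluated only where mask != 0.

    x:      (B, C_in, H, W) input
    weight: (C_out, C_in, kH, kW) filters, kH and kW assumed odd
    bias:   (C_out,) or None
    mask:   (B, H, W); output positions with mask == 0 are left at zero
    """
    B, _, H, W = x.shape
    C_out, _, kH, kW = weight.shape
    # All kH x kW patches, one per output position: (B, H*W, C_in*kH*kW)
    cols = F.unfold(x, kernel_size=(kH, kW), padding=(kH // 2, kW // 2))
    cols = cols.transpose(1, 2)
    # (batch, flat position) indices of the positions we actually compute
    b_idx, pos_idx = mask.reshape(B, -1).nonzero(as_tuple=True)
    patches = cols[b_idx, pos_idx]                  # (N_active, C_in*kH*kW)
    # Apply the filters only to the active patches
    vals = patches @ weight.reshape(C_out, -1).t()  # (N_active, C_out)
    if bias is not None:
        vals = vals + bias
    # Scatter results back; masked-out positions stay zero
    out = x.new_zeros(B, H * W, C_out)
    out[b_idx, pos_idx] = vals
    return out.transpose(1, 2).reshape(B, C_out, H, W)


# Compare with a dense convolution whose output is zeroed by the mask
torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
mask = (torch.rand(2, 8, 8) > 0.7).float()
ref = F.conv2d(x, w, b, padding=1) * mask.unsqueeze(1)
out = sparse_masked_conv2d(x, w, b, mask)
print(torch.allclose(out, ref, atol=1e-4))  # True
```

Positions where the mask is 0 are left at zero, so the result matches a full F.conv2d multiplied by the mask; only the set of positions that is actually computed differs.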

Have you checked the masked conv implementation in mmdetection:

The masked_conv above works, but it is very slow compared with a regular nn.Conv2d.
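Part of why a gather/scatter masked conv tends to be slower than nn.Conv2d is that the dense layer runs on heavily optimized library kernels (cuDNN on GPU), whereas the masked version adds indexing, gathering and scattering around a generic matmul; skipping work only pays off when the mask is very sparse. When the goal is just the masked output rather than saved FLOPs, the "multiply the outputs by the mask" approach from the question is often the faster option in practice. A minimal sketch, with MaskedOutputConv2d as a hypothetical module name, assuming a (B, H, W) mask, stride 1 and an odd kernel size:

```python
import torch
import torch.nn as nn


class MaskedOutputConv2d(nn.Module):
    """Hypothetical module: dense conv, then zero the outputs where mask == 0."""

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        # 'same'-size output for odd kernel sizes with stride 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=kernel_size // 2)

    def forward(self, x, mask):
        # mask: (B, H, W) of 0/1 or bool, broadcast over output channels
        return self.conv(x) * mask.unsqueeze(1).to(x.dtype)


layer = MaskedOutputConv2d(3, 8)
x = torch.randn(2, 3, 16, 16)
mask = torch.rand(2, 16, 16) > 0.5
y = layer(x, mask)
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Every position is still computed, so this does not reduce FLOPs; it only guarantees zeros where the mask is 0. Which approach is faster depends on mask density and hardware, so it is worth timing both on the real inputs.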