Masked Sliding Window tensor

Here is one way you could do this. Just take note that kernel_size and stride should be equal. Otherwise the patches overlap, and since F.fold sums overlapping values, it winds up filling the masked patch back in, in most cases.

import torch
import torch.nn.functional as F

# define values
img = torch.rand(16, 4, 319, 319)
kernel_size = 32
stride = 32

assert kernel_size==stride, "stride and kernel_size must be equal"

#create mask
mask = torch.ones_like(img)
mask = F.unfold(mask, kernel_size=kernel_size, stride=stride) # batch_size, channels*kernel*kernel, patches
N, chw, p = mask.shape
# one copy of the patch columns per patch: (N, c*k*k, p, p)
mask = mask.unsqueeze(3).expand(N, chw, p, p).clone()

# zero out patch i in copy i
diag = torch.eye(p, dtype=torch.bool)
mask[:, :, diag] = 0

N, c, h, w = img.shape
# fold every (copy, channel) pair back into an image: (N*p*c, 1, h, w)
mask = mask.reshape(N, c, kernel_size, kernel_size, p, p)
mask = mask.permute(0, 4, 1, 2, 3, 5).reshape(N*p*c, kernel_size*kernel_size, p)
mask = F.fold(mask, (h, w), kernel_size=kernel_size, stride=stride)
mask = mask.reshape(N, p, c, h, w)

# expand img to one copy per patch: (N, p, c, h, w)
masked_img = img.unsqueeze(1).expand(N, p, c, h, w).clone()

# apply mask
masked_img[~mask.bool()] = 0.
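To see why kernel_size and stride need to match, here is a small sketch that unfolds a tensor of ones and folds it straight back. The `roundtrip` helper and the tiny 6x6 and 5x5 sizes are made up for illustration and are not part of the code above.

```python
import torch
import torch.nn.functional as F

def roundtrip(x, kernel_size, stride):
    """Unfold x into patches and fold it straight back (hypothetical helper)."""
    cols = F.unfold(x, kernel_size=kernel_size, stride=stride)
    return F.fold(cols, tuple(x.shape[-2:]), kernel_size=kernel_size, stride=stride)

ones = torch.ones(1, 1, 6, 6)
same = roundtrip(ones, kernel_size=2, stride=2)     # non-overlapping patches
overlap = roundtrip(ones, kernel_size=2, stride=1)  # overlapping patches

print(torch.equal(same, ones))  # True: every pixel is covered exactly once
print(overlap.max().item())     # 4.0: interior pixels are summed over 4 patches

# A 5x5 input does not divide evenly into 2x2 patches, so the last row and
# column are never covered and come back as zeros.
border = roundtrip(torch.ones(1, 1, 5, 5), kernel_size=2, stride=2)
print(border[0, 0, 4])          # all zeros
```

The last case also applies to the example above: 319 is not a multiple of 32, so the last 31 rows and columns are not covered by any patch. F.fold leaves them at zero in the mask, which means they get zeroed in every copy of masked_img.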

The above steps can also be simplified by substituting img in directly, instead of using a mask. However, since the mask is deterministic, you can make it once, store it in memory, and apply it repeatedly to new batches of images without needing to recreate it, which will be faster if this is going to be repeated.
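As a rough sketch of the "make it once and reuse it" idea: the mask does not depend on the batch, so one option is to build it without the batch dimension and let broadcasting apply it. This is a variant of the code above, not the same construction. `make_patch_masks` is a hypothetical helper name, it assumes stride equals kernel_size, it multiplies by the mask instead of using boolean indexing, and the sizes are shrunk to keep the example cheap.

```python
import torch
import torch.nn.functional as F

def make_patch_masks(c, h, w, kernel_size):
    """Return a (p, c, h, w) bool mask where entry i is False over patch i.

    Hypothetical helper; assumes stride == kernel_size (non-overlapping patches).
    """
    ones = torch.ones(1, c, h, w)
    cols = F.unfold(ones, kernel_size=kernel_size, stride=kernel_size)  # (1, c*k*k, p)
    p = cols.shape[-1]
    cols = cols.expand(p, -1, -1).clone()  # one copy per patch: (p, c*k*k, p)
    idx = torch.arange(p)
    cols[idx, :, idx] = 0                  # in copy i, zero out patch i
    mask = F.fold(cols, (h, w), kernel_size=kernel_size, stride=kernel_size)
    return mask.bool()                     # (p, c, h, w)

mask = make_patch_masks(c=4, h=70, w=70, kernel_size=32)  # built once

for _ in range(2):  # stand-in for iterating over a data loader
    img = torch.rand(2, 4, 70, 70)
    # (N, 1, c, h, w) * (p, c, h, w) broadcasts to (N, p, c, h, w)
    masked_img = img.unsqueeze(1) * mask

print(mask.shape)        # torch.Size([4, 4, 70, 70])
print(masked_img.shape)  # torch.Size([2, 4, 4, 70, 70])
```

Keeping the batch dimension out of the stored mask makes it N times smaller than the (N, p, c, h, w) version, and the same mask then works for any batch size.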

Also, on a side note, I noticed you’re loading a PNG file, which means one of the channels is an alpha channel (used for transparency in PNG files). Not sure if you intended to keep that, but if you would like that channel removed, see here:
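If the image is already loaded as a tensor, a minimal sketch is to slice the channel off. This assumes the layout is (N, 4, H, W) with alpha as the last channel (RGBA order), which is worth checking against however the file is actually loaded. The random tensor is just a stand-in for the loaded batch.

```python
import torch

# Stand-in for a loaded RGBA batch: (N, 4, H, W), alpha assumed to be channel 3.
rgba = torch.rand(16, 4, 319, 319)

rgb = rgba[:, :3]  # keep R, G, B and drop the alpha channel
print(rgb.shape)   # torch.Size([16, 3, 319, 319])
```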