Masked Sliding Window tensor

Hello all,

My problem should be simple, but most solutions I found use a for loop, which I think is making it too slow for my case.

Given an image batch, I want to create a masked batch in which parts of each image are masked in a sliding-window fashion. I tried using unfold and fold, but I still can't work out a simple way to do it.

For example, given an image batch of shape (16, 3, 224, 224) and a sliding window with kernel size 32 and stride 16, I want a new tensor of shape (16, 169, 3, 224, 224), where each of the 169 images has a different window masked.
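The 169 comes from the usual sliding-window count: along an axis of size 224 with kernel 32 and stride 16 there are (224 - 32) // 16 + 1 = 13 window positions, and 13 × 13 = 169 windows in total. A quick check (`num_windows` is just an illustrative helper name):

```python
def num_windows(size: int, kernel: int, stride: int) -> int:
    # window starts 0, stride, 2*stride, ... that still fit entirely inside `size`
    return (size - kernel) // stride + 1


per_axis = num_windows(224, 32, 16)
print(per_axis, per_axis * per_axis)  # 13 169
```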

Could anybody help? Thanks!

Can you write complete code that produces the result you want, even if it uses a loop? You can use img = torch.rand(16, 3, 224, 224) as the input.


Sure! I was able to rework the code a bit, and there is only one for loop left now, which I think could be removed.

import torch


class SlidingWindowMask(torch.nn.Module):
    """Return one copy of the input per window position, with that window zeroed.

    kernel_size and stride are fractions of the image height/width.
    """

    def __init__(self, kernel_size: float, stride: float) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # plain Python ints, so they can be used directly as slice bounds
        top, left, bottom, right = self.generate_coordinates(x).tolist()
        num_masks = len(top)

        # one copy of x per mask: shape (num_masks, *x.shape)
        x_masked = torch.stack([x] * num_masks, dim=0)
        for i in range(num_masks):
            # zero the i-th window: rows top:bottom, columns left:right
            x_masked[i, ..., top[i] : bottom[i], left[i] : right[i]] = 0

        return x_masked

    def get_dimensions(self, x: torch.Tensor) -> tuple[int, int, int, int, int, int]:
        *_, height, width = x.size()

        mask_height = int(self.kernel_size * height)
        mask_width = int(self.kernel_size * width)

        stride_height = int(self.stride * height)
        stride_width = int(self.stride * width)

        return height, width, mask_height, mask_width, stride_height, stride_width

    def generate_coordinates(self, x: torch.Tensor) -> torch.Tensor:
        h, w, m_h, m_w, s_h, s_w = self.get_dimensions(x)

        top, left = self.compute_top_left(h, w, m_h, m_w, s_h, s_w, device=x.device)

        bottom = top + m_h
        right = left + m_w

        # shape (4, num_masks): rows are top, left, bottom, right
        return torch.stack([top, left, bottom, right])

    def compute_top_left(self, h, w, m_h, m_w, s_h, s_w, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        # window start rows come from the height values, start columns from the width values
        tops = torch.arange(0, h - m_h + 1, s_h, **kwargs)
        lefts = torch.arange(0, w - m_w + 1, s_w, **kwargs)

        # every (top, left) combination, flattened row by row
        grid_top, grid_left = torch.meshgrid(tops, lefts, indexing="ij")

        return grid_top.flatten(), grid_left.flatten()
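A minimal sketch of how the remaining loop in forward could be removed with broadcasting. It assumes pixel units for kernel and stride (as in the original question) rather than the fractions the class uses, and `sliding_window_masks` / `apply_masks` are hypothetical helper names, not PyTorch functions. Each window mask is the outer product of "row is inside the window" and "column is inside the window", built by comparing index ranges against the window start positions:

```python
import torch


def sliding_window_masks(height: int, width: int, kernel: int, stride: int) -> torch.Tensor:
    """Bool tensor of shape (P, height, width); True marks the window to hide."""
    tops = torch.arange(0, height - kernel + 1, stride)  # (nh,) window start rows
    lefts = torch.arange(0, width - kernel + 1, stride)  # (nw,) window start columns
    rows = torch.arange(height)
    cols = torch.arange(width)
    # row_in[i, r] is True when row r lies inside the i-th vertical window position
    row_in = (rows >= tops[:, None]) & (rows < tops[:, None] + kernel)  # (nh, H)
    col_in = (cols >= lefts[:, None]) & (cols < lefts[:, None] + kernel)  # (nw, W)
    # outer product -> (nh, nw, H, W), then flatten the window grid row by row
    masks = row_in[:, None, :, None] & col_in[None, :, None, :]
    return masks.reshape(-1, height, width)


def apply_masks(x: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """x: (B, C, H, W), masks: (P, H, W) -> (B, P, C, H, W) with one window zeroed per copy."""
    keep = ~masks[None, :, None]  # (1, P, 1, H, W), False inside the window
    return x[:, None] * keep  # bool is promoted to float; broadcasts to (B, P, C, H, W)


img = torch.rand(2, 3, 224, 224)
masks = sliding_window_masks(224, 224, kernel=32, stride=16)
out = apply_masks(img, masks)
print(out.shape)  # torch.Size([2, 169, 3, 224, 224])
```

The masks depend only on the image size, kernel and stride, so they can be built once and reused. The output is still large: for the (16, 3, 224, 224) batch from the question, the (16, 169, 3, 224, 224) float32 result takes about 1.6 GB, which is why the example uses a batch of 2.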

Here is how I am testing it:

from PIL import Image
import torchvision
import torchvision.transforms.functional as TF
import seaborn_image as isns

from .masking.sliding_window import SlidingWindowMask

# img = torch.rand(16, 3, 224, 224)
img = Image.open(ROOT / "files/images/turtle.png")  # ROOT: project root directory, defined elsewhere
img = TF.to_tensor(img)
print(f"Image shape: {img.shape}")

masking = SlidingWindowMask(kernel_size=0.25, stride=.15)
x = masking(img)
print(f"Masked image shape: {x.shape}")

# flip vertically, then tile the masked copies into a single image for plotting
x = TF.vflip(x)
img_grid = torchvision.utils.make_grid(x, nrow=4, normalize=True, pad_value=1)
img_grid = img_grid.permute(1, 2, 0)  # CHW -> HWC
isns.imshow(img_grid, showticks=False, despine=True, cbar=False)

Output:

Image shape: torch.Size([4, 319, 319])
Masked image shape: torch.Size([16, 4, 319, 319])

[Image: sliding_window_mask, the masked copies of the image arranged in a grid]

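Here is one way you could do this. Note, though, that kernel_size and stride should be equal. F.fold sums overlapping values, so if the windows overlap (stride < kernel_size), the ones from neighbouring patches are added back into the zeroed patch and it is no longer fully masked; if the windows leave gaps (stride > kernel_size), the uncovered pixels fold back to 0 and are masked in every copy.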

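import torch
import torch.nn.functional as F

# define values
img = torch.rand(16, 4, 319, 319)
kernel_size = 32
stride = 32

assert kernel_size == stride, "stride and kernel_size must be equal"

# create mask: unfold a tensor of ones into patches
mask = torch.ones_like(img)
mask = F.unfold(mask, kernel_size=kernel_size, stride=stride)  # batch_size, channels*kernel*kernel, patches
N, chw, p = mask.shape
# one copy of all patches per mask: (N, chw, p, p)
mask = mask.unsqueeze(3).expand(N, chw, p, p).clone()

# in copy j, zero out patch j (a diagonal, hence symmetric, pattern)
diag = torch.diag(torch.ones(p)).bool()
mask[:, :, diag] = 0

# move one patch axis into the batch dim and fold the other back into images
# (the diagonal pattern is symmetric, so either axis works)
N, c, h, w = img.shape
mask = mask.reshape(N, c, kernel_size, kernel_size, p, p)
mask = mask.permute(0, 4, 1, 2, 3, 5).reshape(N * p * c, kernel_size * kernel_size, p)
mask = F.fold(mask, (h, w), kernel_size=kernel_size, stride=stride)
mask = mask.reshape(N, p, c, h, w)

# pixels that no window covers (here the last 319 - 9*32 = 31 rows/columns) also
# fold back to 0; track them so they are not masked in every copy
covered = F.fold(
    F.unfold(torch.ones(1, 1, h, w, device=img.device), kernel_size=kernel_size, stride=stride),
    (h, w), kernel_size=kernel_size, stride=stride,
).bool()

# expand img: one copy per patch
masked_img = img.unsqueeze(1).expand(N, p, c, h, w).clone()

# apply mask (only inside the covered area)
masked_img[~mask.bool() & covered] = 0.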

The steps above could also be simplified by unfolding img itself and zeroing its patches directly, instead of building a separate mask. However, the mask is deterministic (it depends only on the input shape, kernel_size and stride), so if you build it once, you can keep it in memory and apply it to new batches of images without recreating it, which is faster if this is going to be repeated.
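Following the build-once idea, here is a sketch of a variant (a different technique from the code above) that also lifts the kernel_size == stride restriction. Instead of zeroing one patch in a tensor of ones, it folds a one-hot patch indicator: each mask has exactly one non-zero patch, so the summing done by F.fold can never re-fill a masked window, and pixels outside every window simply stay unmasked. `build_fold_masks` and `apply_cached_masks` are hypothetical names:

```python
import torch
import torch.nn.functional as F


def build_fold_masks(height: int, width: int, kernel: int, stride: int) -> torch.Tensor:
    """Bool masks of shape (P, 1, height, width); True inside the window to hide."""
    # same window count as F.unfold with no padding or dilation
    n_h = (height - kernel) // stride + 1
    n_w = (width - kernel) // stride + 1
    p = n_h * n_w
    # one-hot patch indicator: mask j has ones in patch j only, zeros in all other patches
    patches = torch.eye(p).unsqueeze(1).expand(p, kernel * kernel, p)  # (P, k*k, P)
    # F.fold sums overlapping patches, but each mask has a single non-zero patch,
    # so the result is 1 inside that window and 0 everywhere else, including any
    # border that no window reaches
    folded = F.fold(patches, (height, width), kernel_size=kernel, stride=stride)
    return folded.bool()  # (P, 1, H, W)


# build once: the masks depend only on the image size, kernel and stride
masks = build_fold_masks(224, 224, kernel=32, stride=16)  # overlapping windows
keep = (~masks).unsqueeze(0).float()  # (1, P, 1, H, W), 0.0 inside the window


def apply_cached_masks(batch: torch.Tensor) -> torch.Tensor:
    """(B, C, H, W) -> (B, P, C, H, W), reusing the precomputed `keep` tensor."""
    return batch.unsqueeze(1) * keep


out = apply_cached_masks(torch.rand(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 169, 3, 224, 224])
```

The trade-off is memory while building: the indicator is P × kernel² × P values (about 29 million for 224/32/16), but it is only needed once, and after that each batch costs a single broadcast multiplication.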

Also, on a side note: I noticed you’re loading a PNG file, and your image tensor has 4 channels; the fourth is the alpha channel (used for transparency in PNG files). Not sure if you intended to keep it, but if you would like that channel removed, see here: