Sure! I reworked the code a bit, and there is now only one `for` loop left; I think there might be a way to remove that one as well.
import torch
class SlidingWindowMask(torch.nn.Module):
    """Produce one copy of the input per position of a sliding rectangular mask.

    Window size and step are given as fractions of the input's spatial size:
    ``kernel_size`` scales the mask height/width and ``stride`` scales the step
    between consecutive window positions along each axis. For every window
    position the output contains a copy of ``x`` with that window set to zero.
    """

    def __init__(self, kernel_size: float, stride: float) -> None:
        """
        Args:
            kernel_size: Mask side length as a fraction of the input height and
                width; must be in ``(0, 1]`` so at least one window fits.
            stride: Step between windows as a fraction of the input height and
                width; must be positive.

        Raises:
            ValueError: If ``kernel_size`` or ``stride`` is out of range.
        """
        super().__init__()
        if not 0 < kernel_size <= 1:
            raise ValueError(f"kernel_size must be in (0, 1], got {kernel_size}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return masked copies of ``x`` stacked along a new leading dim.

        ``x`` may have any number of leading dims; its last two dims are treated
        as (height, width). The result has shape ``(num_masks, *x.shape)`` and
        entry ``i`` equals ``x`` with window ``i`` zeroed. Windows are ordered
        row-major: top to bottom, and left to right within each row. ``x`` itself
        is not modified.
        """
        top, left, bottom, right = self.generate_coordinates(x)
        *_, height, width = x.shape
        rows = torch.arange(height, device=x.device)
        cols = torch.arange(width, device=x.device)
        # (num_masks, height) / (num_masks, width): rows/cols covered by each window.
        in_rows = (rows >= top[:, None]) & (rows < bottom[:, None])
        in_cols = (cols >= left[:, None]) & (cols < right[:, None])
        # Outer product -> (num_masks, height, width) boolean window masks;
        # this replaces the former per-window Python loop.
        window = in_rows[:, :, None] & in_cols[:, None, :]
        # Insert singleton dims so each mask broadcasts over x's leading dims.
        window = window.reshape(window.shape[0], *([1] * (x.dim() - 2)), height, width)
        zero = torch.zeros((), dtype=x.dtype, device=x.device)
        return torch.where(window, zero, x.unsqueeze(0))

    def get_dimensions(self, x: torch.Tensor) -> tuple[int, int, int, int, int, int]:
        """Convert the fractional kernel/stride settings into pixel sizes for ``x``.

        Returns:
            ``(height, width, mask_height, mask_width, stride_height, stride_width)``
            where height/width are the last two dims of ``x``. Strides are at
            least one pixel so the window always advances.
        """
        *_, height, width = x.size()
        mask_height = int(self.kernel_size * height)
        mask_width = int(self.kernel_size * width)
        # Clamp: a small stride fraction can truncate to 0, which torch.arange rejects.
        stride_height = max(1, int(self.stride * height))
        stride_width = max(1, int(self.stride * width))
        return height, width, mask_height, mask_width, stride_height, stride_width

    def generate_coordinates(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return a ``(4, num_masks)`` tensor of window bounds for ``x``.

        Rows are ``top, left, bottom, right``; bottom/right are exclusive, so
        window ``i`` covers ``x[..., top[i]:bottom[i], left[i]:right[i]]``.
        """
        h, w, m_h, m_w, s_h, s_w = self.get_dimensions(x)
        top, left = self.compute_top_left(h, w, m_h, m_w, s_h, s_w, device=x.device)
        bottom = top + m_h
        right = left + m_w
        return torch.stack([top, left, bottom, right])

    def compute_top_left(self, h, w, m_h, m_w, s_h, s_w, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the flattened top-left corners of every window position.

        Vertical positions step by ``s_h`` over ``[0, h - m_h]`` and horizontal
        positions step by ``s_w`` over ``[0, w - m_w]``, so every window fits
        inside the image. Corners are ordered row-major (left varies fastest).
        Extra ``args``/``kwargs`` (e.g. ``device``) are forwarded to
        ``torch.arange``.
        """
        # Height positions use the height range/stride and width positions the
        # width range/stride (the previous version swapped them, which only
        # worked for square inputs).
        tops = torch.arange(0, h - m_h + 1, s_h, *args, **kwargs)
        lefts = torch.arange(0, w - m_w + 1, s_w, *args, **kwargs)
        grid_top, grid_left = torch.meshgrid(tops, lefts, indexing="ij")
        return grid_top.flatten(), grid_left.flatten()
Here is how I am testing it:
from PIL import Image
from .masking.sliding_window import SlidingWindowMask
import torchvision.transforms.functional as TF
import torchvision
import seaborn_image as isns
# Manual visual check: mask a sample image with SlidingWindowMask and show every
# masked copy in a single grid (not an automated test).
# Alternative synthetic input (note: `torch` is not imported in this snippet):
# img = torch.rand(16, 3, 224, 224)
# NOTE(review): `ROOT` is not defined in this snippet; presumably a pathlib-style
# project root defined elsewhere — confirm.
img = Image.open(ROOT / "files/images/turtle.png")
# PIL image -> tensor; the printed shape below is (C, H, W), e.g. (4, 319, 319).
img = TF.to_tensor(img)
print(f"Image shape: {img.shape}")
# Windows span 25% of each spatial dim and advance by 15% of it per step.
masking = SlidingWindowMask(kernel_size=0.25, stride=.15)
# One masked copy per window position, stacked on a new leading dim.
x = masking(img)
print(f"Masked image shape: {x.shape}")
# Flip every masked copy vertically.
# NOTE(review): the reason is not visible here — presumably it compensates for the
# display orientation; confirm it is still needed.
x = torchvision.transforms.functional.vflip(x)
# Tile the copies into one image, 4 per row, with padding value 1 between tiles.
img_grid = torchvision.utils.make_grid(x, nrow=4, normalize=True, pad_value=1)
# Move channels last for plotting (assumes the grid comes back channels-first).
img_grid = img_grid.permute(1, 2, 0)
isns.imshow(img_grid, showticks=False, despine=True, cbar=False)
Output:
Image shape: torch.Size([4, 319, 319])
Masked image shape: torch.Size([16, 4, 319, 319])
