Sure! I reworked the code a bit, and there is now only one for loop left, which I think could probably be removed as well.
import torch
class SlidingWindowMask(torch.nn.Module):
    """Create one copy of the input per sliding-window position with that window zeroed.

    The window size and the step between consecutive windows are expressed as
    fractions of the input's last two dimensions (height and width).
    """

    def __init__(self, kernel_size: float, stride: float) -> None:
        """Store the relative window size and stride.

        Args:
            kernel_size: Window size as a fraction of the height and width.
            stride: Step between windows as a fraction of the height and width.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return stacked copies of ``x``, each with a different window set to 0.

        Args:
            x: Tensor whose last two dimensions are height and width.

        Returns:
            Tensor of shape ``(num_masks, *x.shape)``. Windows are ordered
            row-major: all column positions for the first row offset, then
            the next row offset, and so on. ``x`` itself is not modified.
        """
        # Convert every coordinate to a Python int in one go instead of
        # turning a 0-dim tensor into an index for each slice bound.
        top, left, bottom, right = self.generate_coordinates(x).tolist()
        x_masked = torch.stack([x] * len(top), dim=0)
        for i, (row0, col0, row1, col1) in enumerate(zip(top, left, bottom, right)):
            # Rows (height, second-to-last dim) span top:bottom and columns
            # (width, last dim) span left:right.
            x_masked[i, ..., row0:row1, col0:col1] = 0
        return x_masked

    def get_dimensions(self, x: torch.Tensor) -> tuple[int, int, int, int, int, int]:
        """Convert the relative window size and stride into pixel counts.

        Args:
            x: Tensor whose last two dimensions are height and width.

        Returns:
            ``(height, width, mask_height, mask_width, stride_height,
            stride_width)``, all in pixels.
        """
        *_, height, width = x.size()
        mask_height = int(self.kernel_size * height)
        mask_width = int(self.kernel_size * width)
        # A stride that truncates to 0 px would make torch.arange fail on a
        # zero step, so advance by at least one pixel.
        stride_height = max(1, int(self.stride * height))
        stride_width = max(1, int(self.stride * width))
        return height, width, mask_height, mask_width, stride_height, stride_width

    def generate_coordinates(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the bounds of every window for ``x``.

        Args:
            x: Tensor whose last two dimensions are height and width.

        Returns:
            Integer tensor of shape ``(4, num_masks)`` on ``x.device`` whose
            rows are ``top``, ``left``, ``bottom`` and ``right``. ``top`` and
            ``bottom`` bound the rows, ``left`` and ``right`` bound the
            columns; ``bottom`` and ``right`` are exclusive.
        """
        h, w, m_h, m_w, s_h, s_w = self.get_dimensions(x)
        top, left = self.compute_top_left(h, w, m_h, m_w, s_h, s_w, device=x.device)
        bottom = top + m_h
        right = left + m_w
        return torch.stack([top, left, bottom, right])

    def compute_top_left(self, h, w, m_h, m_w, s_h, s_w, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """Enumerate the top-left corner of every window position.

        Args:
            h: Image height in pixels.
            w: Image width in pixels.
            m_h: Window height in pixels.
            m_w: Window width in pixels.
            s_h: Vertical step between windows in pixels.
            s_w: Horizontal step between windows in pixels.
            *args: Forwarded to ``torch.arange``.
            **kwargs: Forwarded to ``torch.arange`` (e.g. ``device``).

        Returns:
            ``(top, left)``: flattened row and column offsets of every
            window, in row-major order.
        """
        # Row offsets walk the height with the vertical stride; column
        # offsets walk the width with the horizontal stride.
        top = torch.arange(0, h - m_h + 1, s_h, *args, **kwargs)
        left = torch.arange(0, w - m_w + 1, s_w, *args, **kwargs)
        grid_top, grid_left = torch.meshgrid(top, left, indexing="ij")
        return grid_top.flatten(), grid_left.flatten()
Below is how I am testing it:
from PIL import Image
from .masking.sliding_window import SlidingWindowMask
import torchvision.transforms.functional as TF
import torchvision
import seaborn_image as isns
# Alternative synthetic input for a quick check without the image file:
# img = torch.rand(16, 3, 224, 224)
img = TF.to_tensor(Image.open(ROOT / "files/images/turtle.png"))
print(f"Image shape: {img.shape}")

# Apply the sliding-window masking to the image tensor.
masking = SlidingWindowMask(kernel_size=0.25, stride=0.15)
x = masking(img)
print(f"Masked image shape: {x.shape}")

# NOTE(review): the vertical flip is presumably there to correct the display
# orientation of the plot below -- confirm.
x = TF.vflip(x)

# Tile the masked copies four per row, then move channels last for plotting.
img_grid = torchvision.utils.make_grid(
    x, nrow=4, normalize=True, pad_value=1
).permute(1, 2, 0)
isns.imshow(img_grid, showticks=False, despine=True, cbar=False)
Output:
Image shape: torch.Size([4, 319, 319])
Masked image shape: torch.Size([16, 4, 319, 319])
