Hi all,
I am trying to compute the Sum of Absolute Differences (SAD) metric of an interrogation window over an image. My current implementation performs this calculation by manually sliding the window with two nested for loops. How can I make this more efficient using PyTorch functionalities? I saw some posts suggesting torch.nn.functional.unfold(), but I can’t seem to get it to work.
import torch

def SAD(windows: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
    height, width = windows.shape[-2:]
    num_row, num_column = images.shape[-2] - height, images.shape[-1] - width
    res = torch.zeros((windows.shape[0], num_row + 1, num_column + 1))
    for j in range(num_row + 1):
        for i in range(num_column + 1):
            ref = images[:, j:j+height, i:i+width]
            # per-channel SAD for this offset, written at the mirrored position
            res[:, num_row - j, num_column - i] = torch.sum(torch.abs(windows - ref), dim=(1, 2))
    return res.squeeze()
Thanks!
What are the inputs to your method? unfold sounds like the right approach.
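For reference, a minimal sketch of what unfold returns (toy sizes, just to show the output layout):

import torch
import torch.nn.functional as F

# unfold flattens every sliding (kH, kW) patch of the input into a column
x = torch.arange(16.0).view(1, 1, 4, 4)   # (batch, channels, 4, 4) toy input
cols = F.unfold(x, kernel_size=(2, 2))    # (1, 1*2*2, 9): 9 positions, 4 values each
print(cols.shape)                         # torch.Size([1, 4, 9])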
Hi @ptrblck ,
The inputs to my method are:
- Windows: a tensor of shape (C x H x W), containing C windows of dimensions H x W.
- Images: a tensor of shape (C x M x N), containing C images (one per window). Each M x N image is at least as large as its window.
So for each window-image pair, I want to compute the SAD metric by sliding the window over the image. My approach shown above does this for all pairs simultaneously (all C channels at once), but the code is still quite slow.
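Concretely, for a single offset (j, i) the per-pair value I want is this (toy shapes, just to pin down the definition):

import torch

C, H, W, M, N = 2, 3, 3, 5, 6
windows = torch.randn(C, H, W)
images = torch.randn(C, M, N)

j, i = 1, 2                                        # one window position
patch = images[:, j:j + H, i:i + W]                # (C, H, W)
sad_ji = (windows - patch).abs().sum(dim=(1, 2))   # one SAD value per pair: (C,)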
I found the solution, I will share it down below
import torch

def correlate_intensity_optim(images_a: torch.Tensor, images_b: torch.Tensor) -> torch.Tensor:
    height, width = images_a.shape[-2:]
    num_row, num_column = images_b.shape[-2] - height, images_b.shape[-1] - width
    res = torch.zeros((images_a.shape[0], num_row + 1, num_column + 1))
    for idx, (inp, ref) in enumerate(zip(images_a, images_b)):
        # add batch and channel dims: (1, 1, H, W) and (1, 1, M, N)
        inp, ref = inp.unsqueeze(0).unsqueeze(0).float(), ref.unsqueeze(0).unsqueeze(0).float()
        # every sliding patch flattened into a column: (1, height*width, L)
        unfolded = torch.nn.functional.unfold(ref, (height, width))
        # broadcast-subtract the flattened window from every patch
        conv_out = unfolded.transpose(1, 2) - inp.view(inp.size(0), -1)
        # per-patch sum of absolute differences: (1, L)
        sad = torch.sum(torch.abs(conv_out.transpose(1, 2)), dim=1)
        # reshape the L positions back onto the output grid
        out = torch.nn.functional.fold(sad, (num_row + 1, num_column + 1), (1, 1))
        res[idx] = out
    return res.squeeze()
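For illustration, calling it with toy shapes (the sizes here are made up):

windows = torch.randn(8, 16, 16)   # 8 windows of 16 x 16
areas = torch.randn(8, 32, 32)     # 8 search areas of 32 x 32
maps = correlate_intensity_optim(windows, areas)
print(maps.shape)                  # torch.Size([8, 17, 17])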
The downside to this approach is that it is not faster than my first implementation, because I process the C channels sequentially instead of in parallel; unfolding the full batch at once runs into memory issues when the tensors are too big. If anybody has a workaround, please let me know!
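One possible workaround (an untested sketch, not a drop-in fix; the name sad_chunked and the chunk parameter are mine) is to unfold a chunk of channels at a time, which parallelizes within each chunk while keeping the unfolded (B, H*W, L) tensor bounded in size:

import torch

def sad_chunked(images_a: torch.Tensor, images_b: torch.Tensor, chunk: int = 16) -> torch.Tensor:
    # hypothetical variant of the loop above: process `chunk` channels per unfold call
    height, width = images_a.shape[-2:]
    num_row = images_b.shape[-2] - height
    num_column = images_b.shape[-1] - width
    res = torch.zeros((images_a.shape[0], num_row + 1, num_column + 1))
    for s in range(0, images_a.shape[0], chunk):
        a = images_a[s:s + chunk].float()                          # (B, H, W)
        b = images_b[s:s + chunk].unsqueeze(1).float()             # (B, 1, M, N)
        unfolded = torch.nn.functional.unfold(b, (height, width))  # (B, H*W, L)
        diff = unfolded - a.reshape(a.size(0), -1, 1)              # broadcast each window over all L positions
        res[s:s + chunk] = diff.abs().sum(dim=1).view(-1, num_row + 1, num_column + 1)
    return res.squeeze()

Choosing chunk trades memory for parallelism: chunk=1 reproduces the sequential loop, while chunk=C is the fully batched version that ran out of memory.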
If anybody knows how to improve the original code I wrote, please let me know! It is still very slow…
The code I use currently:
import torch

def SAD(windows: torch.Tensor, areas: torch.Tensor) -> torch.Tensor:
    windows, areas = windows.float(), areas.float()
    (count, window_rows, window_cols), (area_rows, area_cols) = windows.shape, areas.shape[-2:]
    res = torch.zeros((count, area_rows - window_rows + 1, area_cols - window_cols + 1))
    for j in range(area_rows - window_rows + 1):
        for i in range(area_cols - window_cols + 1):
            ref = areas[:, j:j + window_rows, i:i + window_cols]
            # per-channel SAD at offset (j, i)
            res[:, j, i] = torch.sum(windows.sub(ref).abs(), dim=(1, 2))
    return res.squeeze()
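For reference, a call with toy shapes (sizes are illustrative only):

windows = torch.randn(64, 16, 16)   # 64 interrogation windows
areas = torch.randn(64, 24, 24)     # matching search areas
res = SAD(windows, areas)
print(res.shape)                    # torch.Size([64, 9, 9])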