Block Matching algorithm?

Hi, I’m trying to implement the block matching algorithm using torch.

The basic idea is to take a patch, say 8x8, then define a search window, say 50x50, and find the top 10 most similar patches inside that search window.

My idea was to search for the top 10 most similar patches in the whole image using unfold + distance computation + sort, but I'm not quite sure how to do it when there is a search window.

Thanks!

This might be of some help: Pytorch-Correlation-extension (https://github.com/ClementPinard/Pytorch-Correlation-extension)

You mean this might help with the distance step?

Yes. This package may help match patches between feature maps.
I assumed the matching should happen at every pixel, where you consider the 8x8 patch at every pixel position and match it within the 50x50 window around it. Maybe this is not the case?

Not sure of your use case. If it's matching one 8x8 patch with a particular 50x50 window, you could achieve it by convolving (F.conv2d?) the 8x8 patch over the 50x50 window, then sorting the responses.
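
As a rough sketch of that idea (shapes and names here are just examples; also note that raw correlation responses favor bright regions, so for a true similarity you may prefer a distance like SSD or normalized cross-correlation):

import torch
import torch.nn.functional as F

window = torch.randn(1, 1, 50, 50)   # (batch, channels, H, W) search window
patch = window[:, :, 21:29, 21:29]   # an 8x8 patch, here cut out of the window itself

responses = F.conv2d(window, patch)  # (1, 1, 43, 43): one score per placement of the patch
oh, ow = responses.shape[-2:]
values, indices = torch.topk(responses.flatten(), k=10)  # 10 strongest responses
rows, cols = indices // ow, indices % ow                 # back to 2D positions in the window
print(values, rows, cols)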

Yes, exactly! The matching should happen at every pixel.

Let's say there is a 256x256 image: for each pixel of that image, take an 8x8 patch around it and look for similar patches inside a 50x50 search window.
Two nested for loops might work, but I'm not sure if there is a better vectorized tensor solution…
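
Something like this naive version is what I mean (just a sketch, with the helper name and sizes made up, and border handling limited to a simple clamp):

import torch

image = torch.randn(256, 256)
P, W, K = 8, 50, 10  # patch size, search-window size, top-k

def topk_matches(image, y, x):
    # For the PxP patch at (y, x), find the K most similar patches
    # (smallest mean L1 distance) inside the WxW window around (y, x).
    ref = image[y:y + P, x:x + P]
    y0, x0 = max(0, y - W // 2), max(0, x - W // 2)
    y1 = min(image.shape[0] - P, y + W // 2)
    x1 = min(image.shape[1] - P, x + W // 2)
    dists, coords = [], []
    for yy in range(y0, y1 + 1):
        for xx in range(x0, x1 + 1):
            cand = image[yy:yy + P, xx:xx + P]
            dists.append((ref - cand).abs().mean())
            coords.append((yy, xx))
    vals, idx = torch.topk(torch.stack(dists), k=K, largest=False)
    return vals, [coords[i] for i in idx]

vals, coords = topk_matches(image, 100, 100)  # the self-match at (100, 100) comes back with distance 0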

Then my assumption was correct. Thanks for clarifying.
The Pytorch-Correlation-extension package I linked above provides the necessary functionality.
You can look at the examples provided on GitHub and see if it works for you.
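
For example, something along these lines (going from memory of the repo's README, so treat the exact parameter names and values as assumptions and double-check against the examples there):

import torch
from spatial_correlation_sampler import SpatialCorrelationSampler

input1 = torch.randn(1, 1, 64, 64)
input2 = input1.clone()  # matching an image against itself

# kernel_size plays the role of the patch, patch_size the search neighborhood
# (49 rather than 50 so the neighborhood is centered on each pixel)
correlation_sampler = SpatialCorrelationSampler(
    kernel_size=8,
    patch_size=49,
    stride=1,
    padding=0,
    dilation_patch=1,
)
out = correlation_sampler(input1, input2)  # one correlation score per pixel per displacement
print(out.shape)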

Here is one way you could approach the problem in a parallelized way. This assumes you want to compare patches with an L1 loss (though you could easily substitute MSE).

import torch
import torch.nn.functional as F

def l1loss(patch1, patch2, dim=1):
    return torch.mean(torch.abs(patch1 - patch2), dim=dim)

def get_patches(images, kernel_size=(8, 8)):
    return F.unfold(images, kernel_size)  # returns size (batch, flattened_patch, num_patches)

images = torch.randn((1, 1, 50, 50))

patches = get_patches(images)
b, hw, p = patches.shape
patches_exp = patches.unsqueeze(3).expand(b, hw, p, p)  # entry (i, j) holds patch i

z = torch.triu_indices(p, p, 1)  # mapping of the upper triangular matrix, in row-major order

# rot90 turns "entry (i, j) holds patch i" into "entry (i, j) holds patch j",
# so the upper triangle compares every pair of distinct patches exactly once
losses = l1loss(patches_exp.triu(), patches_exp.rot90(k=1, dims=[2, 3]).triu())
losses = losses[losses != 0]  # keep the strictly upper-triangular entries (note: this also drops exact-duplicate pairs, whose loss is genuinely 0)
print(losses.size())  # should be p*(p-1)/2, the non-zero count of triu when the diagonal is zero
values, indices = torch.topk(losses, k=10, largest=False)  # top 10 values and their indices
print(values, indices)
index = 0
x_val = z[0][indices[index]]  # get the indices from the triu_indices mapping at the selected topk index
y_val = z[1][indices[index]]

print(x_val, y_val, values[index])  # check the indices and value
print(l1loss(patches[:, :, x_val], patches[:, :, y_val]))  # check that the loss matches when the indices are applied to the original unfolded patches

.triu() is used to eliminate duplicate calculations and to avoid computing the loss between a patch and itself.
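
To make the triu_indices mapping concrete, a tiny example with p = 4 patches gives p*(p-1)/2 = 6 pairs:

import torch

z = torch.triu_indices(4, 4, 1)
print(z)
# tensor([[0, 0, 0, 1, 1, 2],
#         [1, 2, 3, 2, 3, 3]])
# columns read as the pairs (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)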

Updated to correct an error and include a usage example with topk().

https://pytorch.org/docs/stable/generated/torch.topk.html

Note that with larger images you're going to run into some major memory problems, since the expanded tensor grows with the square of the number of patches, and you may need to split the operation into chunks and iterate.
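
If it helps, here is a rough sketch of that chunked variant (sizes and names are my own; the idea is just to bound the broadcasted difference tensor while collecting per-chunk candidates for a final global top-k):

import torch
import torch.nn.functional as F

images = torch.randn(1, 1, 64, 64)              # small demo image
patches = F.unfold(images, kernel_size=(8, 8))  # (1, 64, p)
p = patches.shape[2]
chunk = 256                                     # tune to your memory budget
k = 10

cand_vals, cand_rows, cand_cols = [], [], []
for start in range(0, p, chunk):
    block = patches[:, :, start:start + chunk]                            # (1, 64, c)
    d = (block.unsqueeze(3) - patches.unsqueeze(2)).abs().mean(dim=1)[0]  # (c, p)
    c = d.shape[0]
    rows = torch.arange(start, start + c)
    # keep each unordered pair once, as triu did: only columns j > row i
    d[torch.arange(p).unsqueeze(0) <= rows.unsqueeze(1)] = float("inf")
    vals, idx = torch.topk(d.flatten(), k=k, largest=False)  # best pairs within this chunk
    cand_vals.append(vals)
    cand_rows.append(rows[idx // p])
    cand_cols.append(idx % p)

# merge the per-chunk candidates into the global top k
best_vals, order = torch.topk(torch.cat(cand_vals), k=k, largest=False)
rows = torch.cat(cand_rows)[order]
cols = torch.cat(cand_cols)[order]
print(best_vals, rows, cols)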