# Block Matching algorithm?

Hi, I’m trying to implement the block matching algorithm using torch.

The basic idea is to take a patch, say 8x8, then define a search window, say 50x50, and find the top 10 most similar patches inside that search window.

My idea was to search for the top 10 most similar patches in the whole image using unfold + distance calculation + sort, but I'm not quite sure how to handle a search window.
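For a single reference patch over the whole image, the unfold + distance + sort idea might be sketched like this (a minimal sketch with made-up sizes, using a squared-L2 distance and `topk` in place of a full sort):

```python
import torch
import torch.nn.functional as F

image = torch.randn(1, 1, 64, 64)           # hypothetical single-channel image
patches = F.unfold(image, kernel_size=8)    # (1, 64, num_patches), one column per 8x8 patch
ref = patches[:, :, 0:1]                    # a reference patch, kept as (1, 64, 1) for broadcasting
dist = (patches - ref).pow(2).sum(dim=1)    # squared L2 distance to every patch, (1, num_patches)

# top-10 most similar; k=11 because the reference patch matches itself with distance 0
values, indices = torch.topk(dist, k=11, largest=False)
```

`indices[:, 0]` will be the reference patch itself, so the 10 most similar other patches are `indices[:, 1:]`.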

Thanks!

This might be of some help:

You mean this might help with the distance step?

Yes. This package may help to match patches between feature maps.
I assumed the matching should happen at every pixel, where you consider the `8x8` patch at every pixel position and match it within the `50x50` window around it. Maybe this is not the case?

Not sure of your use case. If it's matching one `8x8` patch against a particular `50x50` window, you could achieve it by convolving (`F.conv2d`?) this `8x8` patch over the `50x50` window and then sorting the responses.
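A minimal sketch of that single-patch case (tensors and sizes are made up; note that raw correlation responses favor bright regions, so a normalized correlation or an actual distance may be preferable in practice):

```python
import torch
import torch.nn.functional as F

window = torch.randn(1, 1, 50, 50)  # one 50x50 search window
patch = torch.randn(1, 1, 8, 8)     # the 8x8 template, used as a conv kernel

response = F.conv2d(window, patch)  # cross-correlation map, (1, 1, 43, 43)
flat = response.flatten()

values, indices = torch.topk(flat, k=10)   # the 10 strongest responses
rows, cols = indices // 43, indices % 43   # back to 2-D positions in the window
```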

Yeah, right! The matching should happen at every pixel.

Let's say there is a 256x256 image. For each pixel of that image, take an 8x8 patch around it and look for similar patches inside a 50x50 search window.
Two nested for loops might work, but I'm not sure if there is a better vectorized solution using tensors…
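For a single pixel, the search-window case might look like this (`block_match_at` is a hypothetical helper, assuming an L1 distance and only simple clamping at the image borders):

```python
import torch
import torch.nn.functional as F

def block_match_at(image, cy, cx, patch=8, window=50, k=10):
    # Naive baseline for one pixel: crop the search window around (cy, cx),
    # unfold it into candidate patches, and rank them by L1 distance
    # to the reference patch at (cy, cx).
    _, _, H, W = image.shape
    half = window // 2
    y0, y1 = max(cy - half, 0), min(cy + half, H)
    x0, x1 = max(cx - half, 0), min(cx + half, W)

    ref = image[:, :, cy:cy + patch, cx:cx + patch].reshape(1, -1, 1)
    cand = F.unfold(image[:, :, y0:y1, x0:x1], patch)  # (1, patch*patch, n)
    dist = (cand - ref).abs().mean(dim=1)              # (1, n)
    return torch.topk(dist, k=min(k, dist.shape[1]), largest=False)

image = torch.randn(1, 1, 256, 256)
values, indices = block_match_at(image, 100, 100)
```

Looping this over all pixels is the two-nested-loops version; the fully vectorized form trades memory for speed, as discussed below.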

Then my assumption was correct. Thanks for clarifying.
The `Pytorch-Correlation-extension` package I linked above provides the necessary functionality.
You can look at the examples provided in the GitHub repo and see if it works for you.

Here is one way you could approach the problem in a parallelized way. This assumes you want to compare patches with an L1 loss (though you could easily substitute MSE).

```python
import torch
import torch.nn.functional as F

def l1loss(patch1, patch2, dim=1):
    # mean absolute difference along the flattened-patch dimension
    return (patch1 - patch2).abs().mean(dim)

def get_patches(images, kernel_size=(8, 8)):
    # output size: (batch, flattened patch, number of patches)
    return F.unfold(images, kernel_size)

images = torch.randn((1, 1, 50, 50))

patches = get_patches(images)
b, hw, p = patches.shape
patches_exp = patches.unsqueeze(3).expand(b, hw, p, p)

z = torch.triu_indices(p, p, 1)  # mapping of the upper-triangular matrix

losses = l1loss(patches_exp.triu(), patches_exp.rot90(k=1, dims=[2, 3]).triu())
losses = losses[losses != 0]
print(losses.size())  # should be p*(p-1)/2, the non-zeros of triu with a zero diagonal
values, indices = torch.topk(losses, k=10, largest=False)  # 10 smallest losses and their indices
print(values, indices)
index = 0
x_val = z[0][indices[index]]  # recover the patch-pair indices from the triu_indices mapping
y_val = z[1][indices[index]]

print(x_val, y_val, values[index])  # check the indices and value
# verify the loss matches when the indices are applied to the original unfolded patches
print(l1loss(patches[:, :, x_val], patches[:, :, y_val]))
```

`.triu()` is used to eliminate duplicate calculations and to avoid computing the loss between a patch and itself.

Updated to correct an error and include a usage example with `topk()`.

https://pytorch.org/docs/stable/generated/torch.topk.html

Note that with larger images you're going to run into some major memory problems and may need to split the operation up and iterate over chunks.
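One way to split it up is to compute the losses chunk by chunk and keep only a running top-k (`topk_chunked` is a hypothetical helper, not part of the package linked above):

```python
import torch

def topk_chunked(losses_fn, num_items, chunk=1024, k=10):
    # Maintain a running top-k (smallest losses) while processing the
    # candidate pairs in chunks, so the full loss tensor never materializes.
    best_vals = torch.full((k,), float('inf'))
    best_idx = torch.full((k,), -1, dtype=torch.long)
    for start in range(0, num_items, chunk):
        vals = losses_fn(start, min(start + chunk, num_items))  # losses for this chunk
        idx = torch.arange(start, start + vals.numel())
        merged_v = torch.cat([best_vals, vals])
        merged_i = torch.cat([best_idx, idx])
        best_vals, order = torch.topk(merged_v, k, largest=False)
        best_idx = merged_i[order]
    return best_vals, best_idx

# toy usage: the "loss" of item i is simply i, so the 10 smallest are items 0..9
values, indices = topk_chunked(lambda s, e: torch.arange(s, e, dtype=torch.float), num_items=5000)
```

`losses_fn(start, end)` would compute the losses for pairs `start..end` only, e.g. by slicing the expanded patch tensor before calling `l1loss`.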