I am trying to understand this paper through the official code implementation. The paper learns prototypes and computes the distance between input image feature patches and those prototypes. In the code, however, this distance is implemented as follows (refer to model.py, lines 154 to 177):
def _l2_convolution(self, x):
    '''
    apply self.prototype_vectors as l2-convolution filters on input x
    '''
    # x.shape: torch.Size([80, 128, 7, 7])
    pdb.set_trace()  # added by me to inspect the shapes below
    x2 = x ** 2  # why do they square here? x2.shape: torch.Size([80, 128, 7, 7])
    x2_patch_sum = F.conv2d(input=x2, weight=self.ones)  # torch.Size([80, 2000, 7, 7])
    # self.prototype_vectors shape: torch.Size([2000, 128, 1, 1]) - random values
    p2 = self.prototype_vectors ** 2
    # compute the squared L2 norm of the prototype_vectors
    p2 = torch.sum(p2, dim=(1, 2, 3))  # torch.Size([2000])
    # p2 is a vector of shape (num_prototypes,),
    # then reshaped to (num_prototypes, 1, 1) for broadcasting
    p2_reshape = p2.view(-1, 1, 1)  # torch.Size([2000, 1, 1])
    xp = F.conv2d(input=x, weight=self.prototype_vectors)  # torch.Size([80, 2000, 7, 7])
    intermediate_result = -2 * xp + p2_reshape  # uses broadcasting
    # x2_patch_sum and intermediate_result have the same shape
    distances = F.relu(x2_patch_sum + intermediate_result)
    return distances
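For context, my current reading is that the three terms correspond to the expansion of the squared L2 distance, ‖x − p‖² = ‖x‖² − 2⟨x, p⟩ + ‖p‖², evaluated per 1x1 patch via convolutions. Here is a minimal sketch I put together to check that reading against a direct pairwise-distance computation (the sizes are made up, much smaller than the real 80/128/2000/7 shapes, and `ones` plays the role of `self.ones`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small shapes: batch=2, channels=4, spatial 3x3, num_prototypes=5
x = torch.randn(2, 4, 3, 3)
prototypes = torch.randn(5, 4, 1, 1)   # stand-in for self.prototype_vectors
ones = torch.ones(5, 4, 1, 1)          # stand-in for self.ones

# The code's computation: ||x_patch||^2 - 2<x_patch, p> + ||p||^2
x2_patch_sum = F.conv2d(x ** 2, ones)                      # ||x_patch||^2 per prototype slot
p2 = (prototypes ** 2).sum(dim=(1, 2, 3)).view(-1, 1, 1)   # ||p||^2, shape (5, 1, 1)
xp = F.conv2d(x, prototypes)                               # <x_patch, p>
distances = F.relu(x2_patch_sum - 2 * xp + p2)

# Direct computation: squared L2 distance between every 1x1 patch and every prototype
x_flat = x.permute(0, 2, 3, 1).reshape(-1, 4)              # (2*3*3, 4) patches
p_flat = prototypes.view(5, 4)                             # (5, 4)
direct = ((x_flat[:, None, :] - p_flat[None, :, :]) ** 2).sum(-1)  # (18, 5)
direct = direct.view(2, 3, 3, 5).permute(0, 3, 1, 2)       # back to (2, 5, 3, 3)

print(torch.allclose(distances, direct, atol=1e-5))  # True
```

So numerically the function does return squared L2 distances, with the relu only clamping tiny negative values from floating-point error.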
I am not sure how this computes a distance. To me, the line xp = F.conv2d(input=x, weight=self.prototype_vectors)
computes the similarity (a dot product) between feature patches and the (initially random) prototypes, not a distance. I want to get a second opinion on this, since I think I'm missing something. Thank you in advance.
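One thing that might be relevant here: a dot product by itself is not interchangeable with a distance, because patches with large norms can score high on the dot product while still being far from the prototype. A tiny made-up example (1-D patches, hypothetical values) illustrating the difference:

```python
import torch

# Hypothetical patches and prototype, chosen so the contrast is obvious
x_small = torch.tensor([0.1, 0.1])   # patch with small norm, near the prototype
x_big   = torch.tensor([10.0, 0.0])  # patch with large norm, far from the prototype
p       = torch.tensor([1.0, 0.0])   # prototype

for x in (x_small, x_big):
    dot = torch.dot(x, p)             # what a 1x1 conv2d computes per patch
    dist2 = torch.sum((x - p) ** 2)   # squared L2 distance
    print(f"dot={dot:.2f}  squared L2={dist2:.2f}")
# x_big has the larger dot product ("more similar") but also the larger
# squared distance, so the conv output alone cannot serve as the distance;
# the ||x||^2 and ||p||^2 correction terms are what turn it into one.
```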