I want to switch from norm (L2) distance to cosine distance. Please help me convert this function to use cosine distance:
def feat_prototype_distance(self, feat):
    N, C, H, W = feat.shape
    feat_proto_distance = -torch.ones((N, self.class_numbers, H, W)).to(feat.device)
    for i in range(self.class_numbers):
        feat_proto_distance[:, i, :, :] = torch.norm(self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W) - feat, 2, dim=1)
    return feat_proto_distance
This is the original function, which uses the L2 norm distance. The shapes are:
self.objective_vectors[i].reshape(-1,1,1).expand(-1, H, W): torch.Size([256, 128, 224])
feat: torch.Size([8, 256, 128, 224]), where 8 is the batch size
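
One possible way to do the conversion, as a minimal sketch rather than a definitive answer: L2-normalize both the features and the prototypes along the channel dimension, take their dot product to get cosine similarity, and return 1 minus that value. The sketch is written as a standalone function so it can run on its own; the name feat_prototype_cosine_distance and the eps parameter are hypothetical, and it assumes objective_vectors is a single tensor of shape (class_numbers, C), which the question does not state explicitly.

```python
import torch
import torch.nn.functional as F


def feat_prototype_cosine_distance(feat, objective_vectors, eps=1e-8):
    """Cosine distance between every pixel feature and every class prototype.

    feat:              (N, C, H, W) feature map
    objective_vectors: (K, C) class prototypes (assumed shape, K = class_numbers)
    returns:           (N, K, H, W) tensor holding 1 - cosine similarity
    """
    # Unit-normalize along the channel dimension; eps guards against zero vectors.
    feat_n = F.normalize(feat, p=2, dim=1, eps=eps)
    proto_n = F.normalize(objective_vectors, p=2, dim=1, eps=eps)

    # Dot product over channels for all classes at once (replaces the per-class loop).
    cos_sim = torch.einsum("nchw,kc->nkhw", feat_n, proto_n)

    # Smaller value means more similar, the same ordering as the norm distance.
    return 1.0 - cos_sim
```

Inside the original class this would be called as feat_prototype_cosine_distance(feat, self.objective_vectors). The result lies in [0, 2] instead of [0, inf), so any threshold or temperature tuned for the norm distance would probably need to be revisited.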