I want to find the cosine distance between every pair of vectors taken from two tensors.
That is, given [a, b] and [p, q], I want the 2x2 matrix
[[cosDist(a, p), cosDist(a, q)],
 [cosDist(b, p), cosDist(b, q)]]
I want to be able to use this matrix for triplet loss with hard mining.
What is the best way to do this?
Thanks @InnovArul, I had been referring to this code. I was wondering if it is possible to use nn.CosineSimilarity() instead of computing the cosine similarity manually, just to be sure that there are no errors by me.
def cosine_distance_torch(x1, x2=None, eps=1e-8):
    """Return the pairwise cosine-distance matrix between rows of two 2-D tensors.

    Entry ``[i, j]`` of the result is ``1 - cos(x1[i], x2[j])``.

    Args:
        x1: Tensor of shape ``(n, d)``.
        x2: Optional tensor of shape ``(m, d)``. When omitted, distances are
            computed between the rows of ``x1`` itself.
        eps: Lower bound applied to the product of the row norms, so that an
            all-zero row yields a finite distance instead of a division by
            zero.

    Returns:
        Tensor of shape ``(n, m)`` holding the cosine distances.
    """
    x2 = x1 if x2 is None else x2
    w1 = x1.norm(p=2, dim=1, keepdim=True)
    # Reuse the already-computed norms when both operands are the same tensor.
    w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
    # (n, 1) * (1, m) broadcasts to the (n, m) matrix of norm products.
    return 1 - torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
def cosine_similarity_n_space(m1=None, m2=None, dist_batch_size=100):
    """Compute the pairwise cosine-distance matrix between rows of m1 and m2 in batches.

    NOTE: despite the name, the values returned are cosine *distances*
    (``1 - cosine_similarity``), because every batch is delegated to
    ``cosine_distance_torch``. Use ``1 - result`` to obtain similarities.

    Rows of ``m1`` are processed ``dist_batch_size`` at a time on the GPU when
    one is available (CPU otherwise), and each partial result is moved back
    to the CPU so that only one batch of distances lives on the device at a
    time.

    Args:
        m1: 2-D ``torch.Tensor`` or NumPy array of shape ``(n, d)``. NumPy
            input is converted to a float32 tensor; no other conversion is
            supported.
        m2: Optional 2-D ``torch.Tensor`` or NumPy array of shape ``(m, d)``.
            Defaults to ``m1``.
        dist_batch_size: Number of rows of ``m1`` handled per batch.

    Returns:
        ``numpy.ndarray`` of shape ``(n, m)`` whose ``[i, j]`` entry is the
        cosine distance between ``m1[i]`` and ``m2[j]``.

    Raises:
        TypeError: If ``m1`` is ``None`` or is neither a tensor nor a NumPy
            array.
        ValueError: If ``dist_batch_size`` is smaller than 1.
        AssertionError: If ``m1`` and ``m2`` have different numbers of columns.
    """
    if m1 is None:
        raise TypeError("m1 is required (torch.Tensor or numpy.ndarray)")
    if dist_batch_size < 1:
        raise ValueError("dist_batch_size must be a positive integer")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Only NumPy conversion is supported for non-tensor input.
    if not isinstance(m1, torch.Tensor):
        m1 = torch.from_numpy(m1).float()
    if m2 is None:
        m2 = m1
    elif not isinstance(m2, torch.Tensor):
        m2 = torch.from_numpy(m2).float()

    # Only the feature dimension has to agree; the row counts may differ.
    assert m1.shape[1] == m2.shape[1], \
        "m1 and m2 must have the same number of columns"

    # The result is returned as a NumPy array, so no autograd graph is needed.
    with torch.no_grad():
        # m2 is identical for every batch: transfer it to the device once.
        m2_on_device = m2.to(device)
        chunks = [
            cosine_distance_torch(
                m1[start:start + dist_batch_size].to(device), m2_on_device
            ).cpu()
            for start in range(0, m1.shape[0], dist_batch_size)
        ]

    if not chunks:
        # m1 has no rows: return an empty matrix with the right column count.
        return torch.zeros((0, m2.shape[0]), dtype=m1.dtype).numpy()
    # A single concatenation avoids re-copying the growing result per batch.
    return torch.cat(chunks, dim=0).numpy()
You can use this snippet. Cosine similarity is the same as the scalar product of the L2-normalized inputs, and you can get all pairwise scalar products with a single matrix multiplication.
Cosine distance, in turn, is just 1 - cosine_similarity.
def pw_cosine_distance(input_a, input_b):
normalized_input_a = torch.nn.functional.normalize(input_a)
normalized_input_b = torch.nn.functional.normalize(input_b)
res = torch.mm(normalized_input_a, normalized_input_b.T)
res *= -1 # 1-res without copy
res += 1