The code below penalizes the cosine similarity between the different tensors in a batch, but PyTorch has a dedicated CosineSimilarity class that I think could make this code simpler and more efficient. Would it be possible to do the same thing with torch.nn.CosineSimilarity, and if so, how?
import torch
import torch.nn.functional as F

batch = input.size(0)
# One Gram matrix per sample: (batch, channels, channels)
flattened = input.view(batch, input.size(1), -1)
grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
# Scale each Gram matrix to unit Frobenius norm
grams = F.normalize(grams, p=2, dim=(1, 2), eps=1e-10)
# Sum of pairwise products over all i != j, averaged over the batch
loss = -sum([sum([(grams[i] * grams[j]).sum()
                  for j in range(batch) if j != i])
             for i in range(batch)]) / batch
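If I understand it right, the loss relies on the fact that once each Gram matrix has unit Frobenius norm, the elementwise product-sum of two of them is exactly the cosine similarity of the flattened matrices. A quick sanity check with made-up shapes (the 4x3x8x8 input is just for illustration; the normalization is written out by hand instead of F.normalize, but should be equivalent):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(4, 3, 8, 8)  # made-up (batch, channels, H, W) input

batch = input.size(0)
flattened = input.view(batch, input.size(1), -1)
grams = torch.matmul(flattened, flattened.transpose(1, 2))

# Divide each Gram matrix by its Frobenius norm (same effect as
# F.normalize(grams, p=2, dim=(1, 2), eps=1e-10))
norms = grams.flatten(1).norm(dim=1).clamp_min(1e-10)
g = grams / norms.view(-1, 1, 1)

# Product-sum of normalized Grams vs. cosine similarity of flattened Grams;
# these two should agree up to floating-point error
lhs = (g[0] * g[1]).sum()
rhs = F.cosine_similarity(grams[0].view(1, -1), grams[1].view(1, -1))
```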
I've tried variations like the following, but I can't reproduce the result of the code above:
list2 = []
for i in range(input.size(0)):
    list1 = []
    for j in range(input.size(0)):
        similarity = sum(torch.cosine_similarity(input[i].view(1, -1), input[j].view(1, -1)))
        list1.append(similarity)
    list2.append(sum(list1))
loss = -sum(list2)
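My current guess at where the two versions diverge: the attempt above compares the raw inputs rather than the Gram matrices, includes the i == j terms, and never divides by the batch size. A sketch of what I imagine the CosineSimilarity version should look like (unverified; the input shape and variable names are made up):

```python
import torch

torch.manual_seed(0)
input = torch.randn(4, 3, 8, 8)  # made-up (batch, channels, H, W) input

batch = input.size(0)
flattened = input.view(batch, input.size(1), -1)
grams = torch.matmul(flattened, flattened.transpose(1, 2))

# Flatten each Gram matrix to a vector; cosine similarity between these
# vectors should match the normalized elementwise product-sum in the loss
grams_flat = grams.view(batch, -1)

cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-10)
# Broadcasting (batch, 1, D) against (1, batch, D) gives every pairwise
# similarity in a single call: result has shape (batch, batch)
pairwise = cos(grams_flat.unsqueeze(1), grams_flat.unsqueeze(0))

# Drop the diagonal (the i == j terms) and average over the batch
loss = -(pairwise.sum() - pairwise.diagonal().sum()) / batch
```

The broadcast trick replaces both Python loops, so there is only one CosineSimilarity call per batch instead of batch**2 of them.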