How to penalize cosine similarity using torch.nn.CosineSimilarity?

The code below penalizes the cosine similarity between the different tensors in a batch, but PyTorch has a dedicated CosineSimilarity class that I think might make this code less complex and more efficient. Would it be possible to do the same with torch.nn.CosineSimilarity, and if so, how?

import torch
import torch.nn.functional as F

batch = input.size(0)
# Gram matrix per batch element: (B, C, H*W) @ (B, H*W, C) -> (B, C, C)
flattened = input.view(batch, input.size(1), -1)
grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
# Normalize each Gram matrix to unit Frobenius norm
grams = F.normalize(grams, p=2, dim=(1, 2), eps=1e-10)
# Sum similarities over all pairs i != j, averaged over the batch
loss = -sum([sum([(grams[i] * grams[j]).sum()
                  for j in range(batch) if j != i])
             for i in range(batch)]) / batch
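
For reference, my understanding of why the elementwise product here gives a cosine similarity: each Gram matrix has unit Frobenius norm after F.normalize, so the product sum is just a dot product of unit vectors. A quick check of that identity (assuming grams from the code above and a batch size of at least 2):

a, b = grams[0].flatten(), grams[1].flatten()
# The elementwise product sum is the dot product of the flattened matrices...
print(torch.allclose((grams[0] * grams[1]).sum(), torch.dot(a, b)))
# ...and for unit vectors the dot product equals the cosine similarity
print(torch.allclose(torch.dot(a, b), torch.cosine_similarity(a, b, dim=0)))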

I’ve tried variations along these lines, but I can’t get the same result as the above code:

    # I suspect this fails to match because it compares the raw inputs rather
    # than the Gram matrices, includes the i == j self-similarity terms, and
    # never divides by the batch size:
    list2 = []
    for i in range(input.size(0)):
        list1 = []
        for j in range(input.size(0)):
            similarity = torch.cosine_similarity(input[i].view(1, -1),
                                                 input[j].view(1, -1)).sum()
            list1.append(similarity)
        list2.append(sum(list1))
    loss = -sum(list2)

I think I’ve got it working with this:

# Build the Gram matrices as before; cosine_similarity normalizes its
# inputs internally, so the explicit F.normalize step is no longer needed
flattened = input.view(input.size(0), input.size(1), -1)
grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
list2 = []
for i in range(input.size(0)):
    list1 = []
    for j in range(input.size(0)):
        if j != i:
            # Cosine similarity between the flattened Gram matrices
            similarity = torch.cosine_similarity(grams[i].view(1, -1),
                                                 grams[j].view(1, -1)).sum()
            list1.append(similarity)
    list2.append(sum(list1))
loss = -sum(list2) / input.size(0)

Or the last bit rewritten as a list comprehension:

loss = -sum([sum([torch.cosine_similarity(grams[j].view(1, -1),
                                          grams[i].view(1, -1)).sum()
                  for i in range(input.size(0)) if i != j])
             for j in range(input.size(0))]) / input.size(0)
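
I also think the loops can be dropped entirely, since torch.cosine_similarity (and the torch.nn.CosineSimilarity module) broadcast their inputs. This is only a sketch of what I have in mind, under the same input shape assumptions as above, and I haven’t benchmarked it against the loop version:

batch = input.size(0)
flattened = input.view(batch, input.size(1), -1)
grams = torch.matmul(flattened, torch.transpose(flattened, 1, 2))
grams = grams.reshape(batch, -1)  # flatten each Gram matrix to (B, C*C)

# Broadcasting (B, 1, D) against (1, B, D) yields all pairwise similarities (B, B)
cos = torch.nn.CosineSimilarity(dim=2, eps=1e-10)
pairwise = cos(grams.unsqueeze(1), grams.unsqueeze(0))

# Drop the i == j diagonal terms, then negate and average over the batch
loss = -(pairwise.sum() - pairwise.diagonal().sum()) / batch

If that’s correct, is this the intended way to use CosineSimilarity for pairwise comparisons, or is there a more standard idiom?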