How to avoid a for loop while using torch.nn.CosineSimilarity()?

I have two tensors, shapes are shown below.

import torch
import torch.nn as nn

input1 = torch.randn(1, 3, 10, 8)
input2 = torch.randn(1, 1, 128, 8)

I would like to do the following operation:

cos = nn.CosineSimilarity(dim=2, eps=1e-6)
output = cos(input1, input2)

But I get an error saying that the sizes at dim 2 have to match (10 vs. 128).

Therefore, I had to use a for loop as an alternative solution, as below:

cos = nn.CosineSimilarity(dim=-1, eps=1e-6)  # similarity over the length-8 feature dim
output = []
for idx in range(10):
    # slice idx gives (1, 3, 8); unsqueeze to (1, 3, 1, 8) so it broadcasts against input2
    output.append(cos(input1[:, :, idx, :].unsqueeze(2), input2))  # each entry: (1, 3, 128)

This solution works fine. But is it possible to do the same operation without using a for loop?

Any hints are helpful.

Thanks!

output = [cos(input1[:, :, idx, :].unsqueeze(2), input2) for idx in range(10)] should work.
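
If you need a single tensor rather than a list afterwards, the ten (1, 3, 128) results can be stacked (a small sketch, assuming the fixed slicing above):

output = torch.stack(output, dim=2)  # (1, 3, 10, 128)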

However, I am not sure it makes sense to compute the cosine similarity of a subset of the layer.

Thanks for the response. Your solution is still the same as mine, right? :slight_smile: You just use a list comprehension.
The cosine similarity is not for layer outputs, but just for two different vectors.

Can you say what output shape you want?

input1 = torch.randn(1, 3, 10, 8) 
input2 = torch.randn(1, 1, 128, 8)
cos = nn.CosineSimilarity(dim=2, eps=1e-6)
output = cos(input1, input2)

Expected output shape is (1, 3, 10*128, 1) => (1, 3, 1280, 1)

It's like computing the cosine similarity for all pairs of vectors from the two inputs; see the sketch below.
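
Concretely, the whole pairwise computation can be done in a single call by broadcasting, without any loop; a minimal sketch (assuming the similarity is taken over the last, size-8 dimension, and a PyTorch version where `cosine_similarity` broadcasts):

import torch
import torch.nn.functional as F

input1 = torch.randn(1, 3, 10, 8)
input2 = torch.randn(1, 1, 128, 8)

# (1, 3, 10, 1, 8) vs (1, 1, 1, 128, 8): the 10 and 128 vectors broadcast pairwise
output = F.cosine_similarity(input1.unsqueeze(3), input2.unsqueeze(2), dim=-1)  # (1, 3, 10, 128)
output = output.reshape(1, 3, 10 * 128, 1)  # (1, 3, 1280, 1), the expected shape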

@aswamy
import torch
import torch.nn.functional as F

in1 = input1.squeeze()   # (3, 10, 8)
in2 = input2.squeeze()   # (128, 8)

# normalize along the length-8 feature dimension
in11 = F.normalize(in1, dim=-1)
in22 = F.normalize(in2, dim=-1)

dist = torch.cdist(in11, in22, p=2)  # (3, 10, 128); the batch dim broadcasts


The output shape would be `(3, 10, 128)`, where each `(10, 128)` slice holds the 1280 pairwise values you need (as L2 distances between the normalized vectors). Unsqueeze/reshape the appropriate dims to get the desired shape.
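
For example (reshaping `dist` from the snippet above):

dist = dist.reshape(1, 3, 10 * 128, 1)  # (1, 3, 1280, 1)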

The logic behind this is that, for L2-normalized vectors,

`||normalize(v1) - normalize(v2)||^2 = 2 * (1 - CosineSimilarity(v1, v2))`

so the pairwise L2 distances between the normalized vectors are a monotone (decreasing) function of the cosine similarities of the original vectors.

Refer to [this](https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance) for more on it.
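
To recover the actual cosine similarities from those distances, the identity above can be inverted as `cos = 1 - d^2 / 2`. A quick sanity check (a sketch, assuming a PyTorch version where `cosine_similarity` and `cdist` broadcast):

import torch
import torch.nn.functional as F

input1 = torch.randn(1, 3, 10, 8)
input2 = torch.randn(1, 1, 128, 8)

# pairwise L2 distances between the L2-normalized vectors
in11 = F.normalize(input1.squeeze(0), dim=-1)  # (3, 10, 8)
in22 = F.normalize(input2.squeeze(), dim=-1)   # (128, 8)
d = torch.cdist(in11, in22, p=2)               # (3, 10, 128)

# invert ||a - b||^2 = 2 * (1 - cos), which holds for unit vectors a, b
cos_from_dist = 1 - d.pow(2) / 2

# direct computation via broadcasting, for comparison
direct = F.cosine_similarity(input1.unsqueeze(3), input2.unsqueeze(2), dim=-1).squeeze(0)
print(torch.allclose(cos_from_dist, direct, atol=1e-5))  # True, up to floating-point error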