Need help estimating the output shape of a torch.matmul operation

Hi there! Thanks for reading.

I need to compute torch.matmul(a, b) for two tensors a and b, where:

a.size() = torch.Size([8, 1, 32, 256, 256])


b.size() = torch.Size([8, 1, 32, 256, 256])

c = torch.matmul(a, b)

I'm expecting c.size() to be torch.Size([8, 1, 1, 256, 256]).
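For reference, torch.matmul treats the last two dimensions of each input as matrices and broadcasts the leading dimensions as batch dimensions. With these inputs, matmul(a, b) would compute 8 * 1 * 32 ordinary 256x256 matrix products and return shape [8, 1, 32, 256, 256], not [8, 1, 1, 256, 256]. Below is a minimal check. It assumes small stand-in sizes with the same layout to keep memory low; these sizes are not from the original post.

```python
import torch

# Small stand-ins for the real sizes: N=8 -> 2, C=32 -> 4, H=W=256 -> 5
a = torch.randn(2, 1, 4, 5, 5)
b = torch.randn(2, 1, 4, 5, 5)

# matmul multiplies the trailing 5x5 matrices; dims (2, 1, 4) act as batch dims
c = torch.matmul(a, b)
print(c.shape)  # torch.Size([2, 1, 4, 5, 5])
```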

In other words, for each pixel (x, y) of the 256x256 grid, I want the dot product of the two 32-dimensional vectors at that position in a and b (multiply them elementwise and sum).
How can I do this?
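If matmul itself is required, one possible approach is to move the 32-channel dimension to the end. Each pixel's vector then becomes a 1x32 row multiplied by a 32x1 column, and the resulting singleton dimension is moved back into position 2. This is only a sketch. The helper name pixelwise_dot_matmul is hypothetical, and the small tensor sizes are stand-ins for (8, 1, 32, 256, 256).

```python
import torch

def pixelwise_dot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Dot product over dim 2 for tensors shaped (N, 1, C, H, W), via matmul.

    Hypothetical helper for illustration; returns shape (N, 1, 1, H, W).
    """
    # (N, 1, C, H, W) -> (N, 1, H, W, C): put the channel vector last
    a_cl = a.permute(0, 1, 3, 4, 2)
    b_cl = b.permute(0, 1, 3, 4, 2)
    # Row vector (..., 1, C) @ column vector (..., C, 1) -> (..., 1, 1)
    out = torch.matmul(a_cl.unsqueeze(-2), b_cl.unsqueeze(-1))
    # (N, 1, H, W, 1, 1) -> (N, 1, H, W, 1) -> (N, 1, 1, H, W)
    return out.squeeze(-1).permute(0, 1, 4, 2, 3)

a = torch.randn(2, 1, 4, 5, 5)  # small stand-in sizes
b = torch.randn(2, 1, 4, 5, 5)
c = pixelwise_dot_matmul(a, b)
print(c.shape)  # torch.Size([2, 1, 1, 5, 5])
```

This mainly shows how the per-pixel dot product maps onto matmul. An elementwise multiply followed by a sum over dim 2 computes the same values with less reshuffling.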

Thank you so much!


Does this work: (a * b).sum(2, keepdim=True)?
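The suggested one-liner can be checked on small stand-in tensors (assumed sizes, same layout as (8, 1, 32, 256, 256)). The sketch also shows an einsum formulation that contracts over the same channel index. The einsum version is an equivalent alternative and was not part of the original reply.

```python
import torch

a = torch.randn(2, 1, 4, 5, 5)  # small stand-in for (8, 1, 32, 256, 256)
b = torch.randn(2, 1, 4, 5, 5)

# Elementwise product, then sum over dim 2 (the channels); keepdim leaves a size-1 axis there
c = (a * b).sum(2, keepdim=True)
print(c.shape)  # torch.Size([2, 1, 1, 5, 5])

# Same contraction with einsum: sum over channel index c, then re-insert dim 2
c_einsum = torch.einsum("nichw,nichw->nihw", a, b).unsqueeze(2)
print(torch.allclose(c, c_einsum, atol=1e-5))
```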


Indeed, it worked! Thank you so much!