Compute the matrix product of every feature map with every other one in a batch

I have an input of the shape:

N * F * C * W * W,

i.e. F feature maps (each of shape C * W * W) for each example in the batch. Within a single example, I want the matrix product of every feature map with every other feature map (all F * F pairs). The output shape should be

N * F * F * C * W * W
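In index notation, my reading of the desired output is the following, where X is the input, Y the output, C is treated as a batch dimension, and the product is taken over the trailing W * W matrices (the symbols X, Y, a, b, d are introduced here for illustration):

```latex
Y_{n,i,j,c} = X_{n,i,c}\, X_{n,j,c}
\quad\Longleftrightarrow\quad
Y_{n,i,j,c,a,d} = \sum_{b=1}^{W} X_{n,i,c,a,b}\, X_{n,j,c,b,d},
\qquad i, j \in \{1, \dots, F\}
```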

With a plain torch.matmul, feature map i is only multiplied with feature map i, so I cannot model the interaction between every pair of feature maps in an example; the output would be N * F * C * W * W.

I can do this with a Python loop, but that would be inefficient. Is there a vectorized way to do it?
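For reference, the loop-based version might look like the sketch below. This is my own assumption of what the loop would be, and `pairwise_matmul_loop` is a hypothetical name. It issues F * F separate matmul calls, each batched over N and C.

```python
import torch

def pairwise_matmul_loop(x: torch.Tensor) -> torch.Tensor:
    """Naive reference: multiply every feature map with every other one.

    x has shape (N, F, C, W, W); the result has shape (N, F, F, C, W, W).
    """
    n, f, c, w, _ = x.shape
    out = x.new_empty((n, f, f, c, w, w))
    for i in range(f):
        for j in range(f):
            # Batched over N and C: (N, C, W, W) @ (N, C, W, W) -> (N, C, W, W)
            out[:, i, j] = torch.matmul(x[:, i], x[:, j])
    return out

x = torch.randn(2, 3, 4, 5, 5)
print(pairwise_matmul_loop(x).shape)  # torch.Size([2, 3, 3, 4, 5, 5])
```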

I’m not sure if I misunderstand the use case, but wouldn’t this matmul work?

```python
import torch

N, F, C, W = 2, 3, 4, 5
x = torch.randn(N, F, C, W, W)
y = torch.matmul(x.unsqueeze(2), x.unsqueeze(1))
print(y.shape)
# > torch.Size([2, 3, 3, 4, 5, 5])
```
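This works through broadcasting: x.unsqueeze(2) has shape (N, F, 1, C, W, W) and x.unsqueeze(1) has shape (N, 1, F, C, W, W), so torch.matmul broadcasts the leading batch dimensions to (N, F, F, C) and multiplies the trailing W x W matrices, i.e. y[n, i, j] = x[n, i] @ x[n, j]. A quick sanity check against an explicit per-pair matmul, plus an equivalent torch.einsum formulation (an alternative I'm adding, not part of the reply above), could look like this:

```python
import torch

N, F, C, W = 2, 3, 4, 5
x = torch.randn(N, F, C, W, W)

# Broadcasting: (N, F, 1, C, W, W) @ (N, 1, F, C, W, W) -> (N, F, F, C, W, W)
y = torch.matmul(x.unsqueeze(2), x.unsqueeze(1))

# Same result via einsum: for each (n, i, j, c), contract over the shared index b
y_einsum = torch.einsum("nicab,njcbd->nijcad", x, x)

# Compare every (i, j) slice against an explicit batched matmul
for i in range(F):
    for j in range(F):
        assert torch.allclose(y[:, i, j], x[:, i] @ x[:, j], atol=1e-4)

print(y.shape)  # torch.Size([2, 3, 3, 4, 5, 5])
print(torch.allclose(y, y_einsum, atol=1e-4))  # True
```

Either form materializes the full N * F * F * C * W * W output, so memory grows quadratically in F; that cost comes from the requested output shape itself.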