How does one multiply a matrix by its transpose when the matrices are in a batch? I don't want to loop through the batch and multiply each matrix individually.
I have a batch of matrices shaped:
x.shape = [64, 16, 1000]
Where
batches, k_dim, other_dim = x.shape
It seems like the answer is to multiply the matrices using torch.bmm:
out = torch.bmm(x, x.t())
or
out = torch.matmul(x, x.t())
But both options raise errors about shape/size.
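A likely cause: Tensor.t() is only defined for tensors with at most 2 dimensions, so on a 3-D x it raises before bmm or matmul ever runs. The sketch below reproduces this with random data of the stated shape (the helper name try_transpose_t is purely illustrative):

```python
import torch

torch.manual_seed(0)
# Random stand-in for the batch described above: [batches, k_dim, other_dim]
x = torch.randn(64, 16, 1000)

def try_transpose_t(t):
    """Hypothetical helper: return the error message from t.t(), or None if it succeeds."""
    try:
        t.t()
    except RuntimeError as err:
        return str(err)
    return None

msg = try_transpose_t(x)
print(msg)  # .t() rejects the 3-D tensor

# On a single 2-D slice, .t() works and swaps its two dimensions
print(x[0].t().shape)  # torch.Size([1000, 16])
```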
What is the proper way to do this so that I get
out.shape = [64, 16, 16]?
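One way to get that shape, sketched with random data in place of the real batch: swap only the last two dimensions with transpose(1, 2) so the batch dimension stays in place, then torch.bmm (or the @ operator, which broadcasts over leading dimensions) multiplies each [16, 1000] matrix by its [1000, 16] transpose. The per-batch loop at the end exists only to cross-check the result, not as part of the solution.

```python
import torch

torch.manual_seed(0)
batches, k_dim, other_dim = 64, 16, 1000
x = torch.randn(batches, k_dim, other_dim)

# Swap only the last two dims; dim 0 (the batch) is untouched.
xt = x.transpose(1, 2)                   # [64, 1000, 16]

out_bmm = torch.bmm(x, xt)               # [64, 16, 16]
out_matmul = x @ x.transpose(-2, -1)     # matmul treats dim 0 as a batch dim

# Verification only: explicit per-matrix loop
out_loop = torch.stack([x[i] @ x[i].t() for i in range(batches)])

print(out_bmm.shape)                                   # torch.Size([64, 16, 16])
print(torch.allclose(out_bmm, out_matmul, atol=1e-3))
print(torch.allclose(out_bmm, out_loop, atol=1e-3))
```

transpose returns a view rather than a copy, so no extra memory is spent on the transposed batch; a loose atol is used in the comparisons because float32 sums over 1000 terms may differ slightly depending on the kernel's summation order.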