How does one multiply a matrix by its transpose in a batch?

How does one multiply each matrix in a batch by its own transpose? I don't want to loop through the batch and perform the multiplication on each matrix individually.

I have a batch of matrices shaped:

x.shape = [64, 16, 1000]

Where

batches, k_dim, other_dim = x.shape

It seems like the answer is to multiply the matrices using bmm:

out = torch.bmm(x, x.t())

or

out = torch.matmul(x, x.t())

But both options raise errors about shape/size.
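A quick check suggests where the error comes from: `Tensor.t()` is only defined for tensors with at most 2 dimensions, so `x.t()` fails on the 3-D batch before `bmm` or `matmul` ever runs. The sketch below assumes the same shape as above; `transpose(-2, -1)` swaps only the last two dimensions and keeps the batch dimension where it is.

```python
import torch

x = torch.rand(64, 16, 1000)

# Tensor.t() only accepts tensors with <= 2 dimensions, so on a 3-D batch
# it raises a RuntimeError (this is the error bmm/matmul never get to see).
try:
    x.t()
    t_failed = False
except RuntimeError as err:
    t_failed = True
    print(err)

# Swapping just the last two dimensions transposes each matrix in the batch.
xt = x.transpose(-2, -1)
print(xt.shape)  # torch.Size([64, 1000, 16])
```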

What is the proper way to do this so that the result has shape

out.shape = [64, 16, 16]?

Hello,

You can use permute to swap the last two dimensions, which transposes every matrix in the batch while leaving the batch dimension in place:

    import torch

    x = torch.rand(size=(64, 16, 1000), dtype=torch.float32)
    # permute(0, 2, 1) turns [64, 16, 1000] into [64, 1000, 16]
    out = torch.bmm(x, x.permute(0, 2, 1))
    print(out.shape)
    # torch.Size([64, 16, 16])
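For comparison, here is a sketch of a few other ways to get the same batched product, none of which are part of the original answer: `x.transpose(1, 2)`, the `.mT` property (which transposes the last two dimensions and is only available in recent PyTorch versions), and `torch.einsum`. The explicit loop is included purely as a reference to check that the batched versions agree; the tolerances are an assumption chosen to absorb float32 rounding differences.

```python
import torch

x = torch.rand(64, 16, 1000, dtype=torch.float32)

# Equivalent ways to compute X @ X^T for every matrix X in the batch.
out_transpose = torch.bmm(x, x.transpose(1, 2))
out_matmul = x @ x.mT  # .mT swaps the last two dims (recent PyTorch only)
out_einsum = torch.einsum("bij,bkj->bik", x, x)  # sum over the shared j index

# Reference computed with a Python loop, used only to verify the results.
out_loop = torch.stack([m @ m.T for m in x])

for out in (out_transpose, out_matmul, out_einsum):
    assert out.shape == (64, 16, 16)
    assert torch.allclose(out, out_loop, rtol=1e-4, atol=1e-4)
```

All three batched versions avoid the Python loop; `bmm` with `transpose` or `permute` is the most direct, while `einsum` makes the contraction explicit in the subscript string.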