Hello,
You can use permute to swap the last two dimensions, so that torch.bmm multiplies each 16 x 1000 matrix in the batch by its 1000 x 16 transpose:
import torch

x = torch.rand(size=(64, 16, 1000), dtype=torch.float32)
# permute(0, 2, 1) keeps the batch dim and swaps the other two: (64, 1000, 16)
out = torch.bmm(x, x.permute(0, 2, 1))
print(out.shape)
# torch.Size([64, 16, 16])
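Here is a minimal sketch, assuming a reasonably recent PyTorch, that compares a few equivalent ways to get the same batched x times x-transpose product. transpose, the @ operator and einsum are standard PyTorch alternatives that weren't part of the answer above. Each version gives a (64, 16, 16) tensor where out[b] is the 16 x 16 Gram matrix of the rows of x[b].

```python
import torch

# Batch of 64 matrices, each 16 x 1000 (same shape as the example above)
x = torch.rand(size=(64, 16, 1000), dtype=torch.float32)

# permute(0, 2, 1) and transpose(1, 2) both swap the last two dims (as views, no copy)
via_permute = torch.bmm(x, x.permute(0, 2, 1))
via_transpose = torch.bmm(x, x.transpose(1, 2))

# The @ operator (torch.matmul) treats the leading dim as a batch dim
via_matmul = x @ x.transpose(-2, -1)

# einsum spells out the contraction over the last (length-1000) axis explicitly
via_einsum = torch.einsum("bik,bjk->bij", x, x)

print(via_permute.shape)  # torch.Size([64, 16, 16])

# Results can differ by tiny float32 rounding, so compare with a tolerance
print(torch.allclose(via_permute, via_matmul, rtol=1e-4))
print(torch.allclose(via_permute, via_einsum, rtol=1e-4))
```

The transpose/@ form is often easier to read because it states which two dimensions are being swapped, while permute is the more general tool when you need to reorder more than two dimensions.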