Hello, PyTorch team! I am confused about torch.bmm. In my code below:

query_states and key_states are both 4-D tensors.

As the docs say: torch.bmm(input, mat2, *, out=None) → Tensor: input and mat2 must be 3-D tensors each containing the same number of matrices.
So why doesn't torch.bmm raise an error here? My PyTorch version is '2.1.0+cpu'. Thanks!
Do you have a short reproducible snippet?
>>> import torch
>>> a = torch.rand(1, 40, 3, 128)
>>> b = torch.rand(1, 40, 3, 128)
>>> torch.bmm(a, b.transpose(2, 3))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: batch1 must be a 3D tensor
Also, could you try again on a later PyTorch version?
No, I can't reproduce it in the Python terminal, which reports an error: batch1 must be a 3D tensor. OK, I will use torch.matmul instead. I'm just confused about this.
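For anyone landing here with the same question, here is a minimal sketch of the two usual options for 4-D inputs, assuming the same shapes as the snippet above. torch.matmul broadcasts over the leading batch dimensions, while torch.bmm needs them flattened into a single batch dimension first. The helper names scores_with_matmul and scores_with_bmm are made up for illustration and are not part of PyTorch.

```python
import torch

def scores_with_matmul(q, k):
    # torch.matmul treats all leading dimensions as batch dimensions,
    # so 4-D inputs work directly.
    return torch.matmul(q, k.transpose(-2, -1))

def scores_with_bmm(q, k):
    # torch.bmm only accepts 3-D tensors, so flatten the leading
    # dimensions into one batch dimension, multiply, then restore them.
    batch_shape = q.shape[:-2]
    q3 = q.reshape(-1, q.shape[-2], q.shape[-1])
    k3 = k.reshape(-1, k.shape[-2], k.shape[-1])
    out = torch.bmm(q3, k3.transpose(1, 2))
    return out.reshape(*batch_shape, q.shape[-2], k.shape[-2])

a = torch.rand(1, 40, 3, 128)
b = torch.rand(1, 40, 3, 128)

scores_matmul = scores_with_matmul(a, b)
scores_bmm = scores_with_bmm(a, b)

print(scores_matmul.shape)  # torch.Size([1, 40, 3, 3])
# Both paths should agree up to floating-point rounding.
print(torch.allclose(scores_matmul, scores_bmm, atol=1e-4))
```

If the original code really did pass 4-D tensors to torch.bmm without an error, it may be worth printing query_states.shape right before the call, since the tensors could already have been reshaped to 3-D by that point. That is only a guess, as the original code is not shown here.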