Confused about torch.bmm

Hello, PyTorch team! I'm confused about torch.bmm. In my code below:

query_states and key_states are both 4-D tensors.


According to the docs, torch.bmm(input, mat2, *, out=None) → Tensor requires that input and mat2 be 3-D tensors, each containing the same number of matrices.
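A minimal sketch of what the 3-D requirement means in practice. It reuses the shapes from the reproduction later in this thread, (1, 40, 3, 128). Reading them as (batch, heads, seq_len, head_dim) is an assumption for illustration, and bmm_4d is a hypothetical helper, not a PyTorch API. To use torch.bmm on 4-D inputs, you first fold the two leading dimensions into one batch dimension, then unfold them afterwards:

```python
import torch

# Shapes from the reproduction in this thread; treating them as
# (batch, heads, seq_len, head_dim) is an assumption for illustration.
q = torch.rand(1, 40, 3, 128)
k = torch.rand(1, 40, 3, 128)

def bmm_4d(a, b):
    """Hypothetical helper: batched matmul of 4-D tensors via torch.bmm,
    which itself only accepts 3-D inputs."""
    n, h = a.shape[:2]
    # Fold (batch, heads) into a single batch dimension: (n*h, rows, cols).
    # reshape (not view) also handles the non-contiguous transposed input.
    a3 = a.reshape(n * h, a.shape[2], a.shape[3])
    b3 = b.reshape(n * h, b.shape[2], b.shape[3])
    out = torch.bmm(a3, b3)  # (n*h, a_rows, b_cols)
    # Restore the original leading dimensions.
    return out.reshape(n, h, out.shape[1], out.shape[2])

scores = bmm_4d(q, k.transpose(2, 3))
print(scores.shape)  # torch.Size([1, 40, 3, 3])

# Passing the 4-D tensors directly is rejected, as the docs describe.
try:
    torch.bmm(q, k.transpose(2, 3))
except RuntimeError as e:
    print("RuntimeError:", e)
```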

So why doesn't torch.bmm raise an error here? My PyTorch version is 2.1.0+cpu. Thanks!

Do you have a short reproducible snippet?

>>> a = torch.rand(1, 40, 3, 128)
>>> b = torch.rand(1, 40, 3, 128)
>>> torch.bmm(a, b.transpose(2, 3))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: batch1 must be a 3D tensor

Also, could you try again on a later PyTorch version?

No, I can't reproduce it in a Python terminal. There it raises the error "batch1 must be a 3D tensor". OK, I'll use torch.matmul instead, but I'm still confused about this.
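For reference, here is a minimal sketch of the torch.matmul route mentioned above. It uses the same shapes as the reproduction in this thread, with random inputs for illustration only. torch.matmul treats every dimension except the last two as a batch dimension and broadcasts them, so 4-D inputs need no reshaping:

```python
import torch

q = torch.rand(1, 40, 3, 128)
k = torch.rand(1, 40, 3, 128)

# torch.matmul treats all but the last two dimensions as batch dimensions,
# so the 4-D tensors can be multiplied directly.
scores = torch.matmul(q, k.transpose(-2, -1))
print(scores.shape)  # torch.Size([1, 40, 3, 3])

# For tensors, the @ operator dispatches to torch.matmul.
same = torch.allclose(scores, q @ k.transpose(-2, -1))

# Unlike torch.bmm, matmul also broadcasts: a single (128, 3) matrix
# is applied to every one of the 1 * 40 batched (3, 128) matrices.
w = torch.rand(128, 3)
print(torch.matmul(q, w).shape)  # torch.Size([1, 40, 3, 3])
```

The design trade-off is that torch.bmm is strict: 3-D only, with no broadcasting. That strictness is why it rejects the 4-D tensors in the reproduction, while torch.matmul accepts them.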