I have pasted my code snippet here:
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E)
# scores = torch.einsum("blhe,bshe->bhls", queries, keys)
scores = torch.matmul(queries, keys)
I am trying to use matmul instead of einsum, but I get the following error:
scores = torch.matmul(keys,queries)
RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Shapes of the arguments are as follows:
shape of queries is torch.Size([1024, 6, 8, 4])
shape of keys is torch.Size([1024, 6, 8, 4])
Could you please help me with this error? I know that my matrix shapes are incompatible for matrix multiplication as they stand, but how do I use permute or transpose so that the final shape of scores is ([1024, 8, 6, 6])?
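A minimal sketch of one way this could be done, assuming queries has layout (B, L, H, E) and keys has layout (B, S, H, E) as in the commented-out einsum. The random tensors are stand-ins for the real inputs, and the names q, k and reference are only illustrative. torch.matmul treats the last two dimensions as the matrices and batches over the leading ones, so the head dimension is moved next to the batch dimension and keys is arranged so that E is the contracted axis:

```python
import torch

# Stand-in inputs with the shapes reported above (assumed layout: B, L, H, E).
B, L, H, E = 1024, 6, 8, 4
S = 6
queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)

# Move heads next to the batch dimension: (B, L, H, E) -> (B, H, L, E).
q = queries.permute(0, 2, 1, 3)
# Arrange keys so E is second to last: (B, S, H, E) -> (B, H, E, S).
k = keys.permute(0, 2, 3, 1)

# Batched matmul over (B, H): (L, E) @ (E, S) -> (L, S).
scores = torch.matmul(q, k)

# Same contraction written with the original einsum, for comparison.
reference = torch.einsum("blhe,bshe->bhls", queries, keys)

print(scores.shape)
print(torch.allclose(scores, reference, atol=1e-5))
```

The order of the arguments matters here: with the operands swapped, the result would come out as (B, H, S, L) instead of (B, H, L, S), which only happens to have the same shape because L and S are both 6 in this case.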