I have two tensors, a and b, each with 4 dimensions and the same shape, and I want to write the following operation using einsum:
output = torch.matmul(a, b.transpose(-2,-1))
This is what I have so far:
output = torch.einsum("ijkl,jmnl->imkn", [a, b])
But this is not correct. How do I write this 4D tensor multiplication using einsum?
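A minimal sketch of one way to express it. `torch.matmul` on 4D tensors treats the first two dimensions as batch dimensions and multiplies the last two, so `matmul(a, b.transpose(-2, -1))` computes `out[i, j, k, m] = sum over l of a[i, j, k, l] * b[i, j, m, l]`. In einsum, labels that appear in both inputs but not in the output are summed. The attempt above therefore also sums over `j` and pairs a's second dimension with b's first, rather than keeping both batch dimensions aligned. The shapes below are arbitrary example values (assumptions), chosen so that `k` and `m` differ and any index mix-up would show in the output shape:

```python
import torch

torch.manual_seed(0)

# Example shapes (assumed): two batch dims (2, 3), then rows x features.
# The row counts of a (4) and b (6) differ only to make index mistakes visible.
a = torch.randn(2, 3, 4, 5)
b = torch.randn(2, 3, 6, 5)

# Reference: batched matmul over the last two dims, with b's last two dims swapped.
ref = torch.matmul(a, b.transpose(-2, -1))  # shape (2, 3, 4, 6)

# einsum: i and j are batch dims kept in the output, l is contracted (summed),
# k indexes rows of a, and m indexes rows of b.
out = torch.einsum("ijkl,ijml->ijkm", a, b)

# Same contraction, with "..." standing in for any number of leading batch dims.
out_ellipsis = torch.einsum("...kl,...ml->...km", a, b)

print(out.shape)  # torch.Size([2, 3, 4, 6])
print(torch.allclose(out, ref, atol=1e-5))
print(torch.allclose(out_ellipsis, ref, atol=1e-5))
```

The ellipsis form is a convenient choice when the number of batch dimensions may vary. It mirrors how `torch.matmul` broadcasts over all leading dimensions.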