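You don’t want to use reshape() to swap dimensions. reshape() only changes the shape; it keeps the elements in their original memory order instead of moving them between axes, so the values in your tensor end up in the wrong positions even though the shape looks right. More details here. You probably want: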
x = x.transpose(1, 2)
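Here is a quick illustration of the difference. It is a minimal sketch that assumes PyTorch and a toy 3-D tensor shaped (batch, seq_len, features); the shapes are made up for the example. Both calls produce a tensor of shape (1, 3, 2), but only transpose() actually swaps the axes:

```python
import torch

# Toy tensor, assumed layout (batch=1, seq_len=2, features=3)
x = torch.arange(6).reshape(1, 2, 3)

# transpose(1, 2) swaps axes 1 and 2: element [b][i][j] moves to [b][j][i]
swapped = x.transpose(1, 2)

# reshape(1, 3, 2) has the same target shape, but it just re-chunks
# the elements in memory order, so nothing is actually swapped
reshaped = x.reshape(1, 3, 2)

print(x[0].tolist())         # [[0, 1, 2], [3, 4, 5]]
print(swapped[0].tolist())   # [[0, 3], [1, 4], [2, 5]]
print(reshaped[0].tolist())  # [[0, 1], [2, 3], [4, 5]]
```

Note that transpose() returns a non-contiguous view, so if a later op such as view() complains, call .contiguous() on the result first.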