Loss stays around 5.6 and doesn't decrease; accuracy is 0.64% - 0.84%

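You don’t want to use reshape() to swap dimensions. reshape() keeps the elements in their original memory order and only regroups them into the new shape, so values from different dimensions get mixed together and the model trains on scrambled inputs. To actually swap two dimensions, you probably want: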

x = x.transpose(1, 2)
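
To see the difference, here is a minimal sketch on a tiny tensor. The shape (1, 2, 3) and the names `swapped` and `regrouped` are made up for illustration and are not taken from the original model:

```python
import torch

# Hypothetical input of shape (batch=1, seq_len=2, features=3)
x = torch.arange(6).reshape(1, 2, 3)
# x[0] is [[0, 1, 2],
#          [3, 4, 5]]

# transpose() really swaps axes 1 and 2: rows become columns
swapped = x.transpose(1, 2)

# reshape() gives the same shape (1, 3, 2) but only regroups
# the elements in their original memory order
regrouped = x.reshape(1, 3, 2)

print(swapped.shape, regrouped.shape)   # both torch.Size([1, 3, 2])
print(swapped[0].tolist())              # [[0, 3], [1, 4], [2, 5]]
print(regrouped[0].tolist())            # [[0, 1], [2, 3], [4, 5]]
print(torch.equal(swapped, regrouped))  # False
```

Both results have the same shape, so nothing fails at runtime, which is why this kind of bug usually only shows up as a loss that refuses to go down. Note that transpose() returns a non-contiguous view, so if a later step calls view() on it you may need x.transpose(1, 2).contiguous() first.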