Someone on the forums suggested that I remove the softmax layer and the MSE loss and try this instead:
import torch.nn.functional as F

def softXEnt(input, target):
    # cross-entropy against soft targets; note that .mean() averages
    # over every element (batch * classes), not just over the batch
    logprobs = F.log_softmax(input, dim=1)
    return -(target * logprobs).mean()
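For context, here is roughly how I'm calling it — a minimal sketch with made-up shapes (batch of 4, 10 classes) and random tensors, just to show the inputs I'm feeding it:

import torch
import torch.nn.functional as F

# raw logits from the model, shape (batch, num_classes)
logits = torch.randn(4, 10, requires_grad=True)
# soft targets: each row is a probability distribution over the classes
target = F.softmax(torch.randn(4, 10), dim=1)

loss = softXEnt(logits, target)  # softXEnt as defined above
loss.backward()

One thing I noticed while writing this up: because .mean() divides by batch_size * num_classes rather than just the batch size, the loss comes out a constant factor smaller than the usual sum-over-classes formulation; as far as I can tell that only rescales the gradients.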
In this case the loss decreases, but the model still learns poorly, as I can see from its predictions.