I guess there is something wrong in the original code that breaks the computation graph and keeps the loss from decreasing. I suspect it is this line:
pt = Variable(pred_prob_oh.data.gather(1, target.data.view(-1, 1)), requires_grad=True)
Does torch.gather support autograd? Is there any way to implement this without breaking the graph?
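For context, here is a minimal sketch of the surrounding computation as I understand it. The names `pred_prob_oh` and `target` are from my code; the toy `logits`, the shapes, and the focal-style weighting with gamma = 2 are just placeholders to make it self-contained. The second `pt` is what I suspect I should use instead of the `.data` version:

```python
import torch
from torch.autograd import Variable
import torch.nn.functional as F

# Toy setup: batch of 4 samples, 3 classes (hypothetical shapes).
logits = Variable(torch.randn(4, 3), requires_grad=True)
target = Variable(torch.LongTensor([0, 2, 1, 2]))

pred_prob_oh = F.softmax(logits, dim=1)

# Original line: gathering from .data and re-wrapping in a new Variable
# detaches pt from the graph, so no gradient can flow back to the logits.
pt_detached = Variable(pred_prob_oh.data.gather(1, target.data.view(-1, 1)),
                       requires_grad=True)

# Calling gather on the Variable itself is differentiable and keeps the graph.
pt = pred_prob_oh.gather(1, target.view(-1, 1))

loss = -(1 - pt).pow(2) * pt.log()   # focal-style weighting, gamma = 2
loss.mean().backward()
print(logits.grad)                   # non-None: gradients reach the logits
```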
Many thanks!