How to implement focal loss in PyTorch?

I think something in the original code breaks the computation graph, so the loss does not decrease. I suspect it is this line:

    pt = Variable(pred_prob_oh.data.gather(1, target.data.view(-1, 1)), requires_grad=True)

Does torch.gather support autograd? Is there any way to implement this?
Many thanks!
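
For reference, here is a minimal sketch of what I think a graph-preserving version might look like. As far as I understand, `torch.gather` is differentiable with respect to its input, while taking `.data` and wrapping the result in a new `Variable` creates a fresh leaf that is detached from the model. The names `focal_loss`, `logits`, and `gamma` are my own placeholders (not from the original code), and I am assuming the inputs are raw scores of shape (N, C) with integer class targets of shape (N,):

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, target, gamma=2.0):
    """Hypothetical focal loss sketch; assumes logits (N, C), target (N,) int64."""
    # Log-probabilities are numerically safer than softmax followed by log.
    log_prob = F.log_softmax(logits, dim=1)
    # gather is called on the tensor itself (no .data, no new Variable),
    # so the result stays connected to the computation graph.
    log_pt = log_prob.gather(1, target.view(-1, 1)).squeeze(1)
    pt = log_pt.exp()
    # Down-weight well-classified examples by (1 - pt) ** gamma.
    loss = -((1.0 - pt) ** gamma) * log_pt
    return loss.mean()


torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 2])

loss = focal_loss(logits, target)
loss.backward()  # gradients should now reach logits
```

With `gamma=0.0` the weighting term becomes 1, so this sketch should reduce to ordinary cross entropy, which is one way to sanity-check it.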
