These are some lines from my loss function; `output` is the output of a multiclass classification network.
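```python
# number of classes scoring above 0.1 for each sample (one bin per row)
bin_count = torch.bincount(torch.where(output > .1)[0], minlength=output.shape[0])
# True where exactly one class passes 0.1 and the predicted class matches the label
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1) == labels)
```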
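For reference, here is a small self-contained run of these two lines on a made-up batch. It assumes `output` holds per-class probabilities and `labels` holds class indices (the batch values are invented for illustration), and shows what the lines compute and why no gradient can flow back through them:

```python
import torch

# Hypothetical batch: 3 samples, 4 classes, rows are class probabilities.
output = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                       [0.50, 0.40, 0.05, 0.05],
                       [0.02, 0.03, 0.90, 0.05]], requires_grad=True)
labels = torch.tensor([0, 0, 1])

# Row index of every entry above the 0.1 threshold.
rows = torch.where(output > .1)[0]
# How many classes clear 0.1 in each sample.
bin_count = torch.bincount(rows, minlength=output.shape[0])
# Exactly one class above 0.1 AND the argmax equals the label.
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1) == labels)

print(bin_count.tolist())       # [1, 2, 1]
print(dr_output.tolist())       # [True, False, False]
# The comparisons return bool tensors with no grad_fn, so calling
# dr_output.sum().backward() would raise a RuntimeError.
print(dr_output.requires_grad)  # False
```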
I want `dr_output.sum()` to be part of my loss function, but my implementation has several problems. Some of the operations it uses (the `> .1` comparison, `torch.bincount` and `torch.argmax`) are non-differentiable in PyTorch, and `dr_output` may be all zeros, which is also a problem if I use `dr_output` alone as my loss. Can anyone suggest a way around these problems?
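One common way around both problems, offered as a sketch rather than a definitive fix: replace each hard operation with a smooth surrogate, and add the result to an ordinary cross-entropy term instead of using it alone. Note that `(bin_count == 1) & (argmax == labels)` is True exactly when the true class is the only class above 0.1 (if only one class clears 0.1, it is necessarily the argmax), so a sigmoid with a small temperature can stand in for each threshold test. The function name `soft_dr_loss`, the `temperature` and `weight` parameters, and the assumption that `output` equals `softmax(logits)` are all hypothetical.

```python
import torch
import torch.nn.functional as F


def soft_dr_loss(logits, labels, threshold=0.1, temperature=0.02, weight=1.0):
    """Cross-entropy plus a smooth surrogate for dr_output (hypothetical helper).

    Assumes `logits` are raw scores of shape (N, C), `labels` are class
    indices of shape (N,), and `output` in the question is softmax(logits).
    """
    probs = F.softmax(logits, dim=1)
    # z > 0 means "above threshold"; sigmoid(z) is a soft (probs > threshold).
    z = (probs - threshold) / temperature
    true_mask = F.one_hot(labels, num_classes=logits.shape[1]).bool()
    # Log of the soft score for "true class above, every other class below",
    # computed in log space so the product over classes cannot underflow.
    log_soft_dr = torch.where(true_mask, F.logsigmoid(z), F.logsigmoid(-z)).sum(dim=1)
    soft_dr = log_soft_dr.exp()  # in (0, 1); near 1 where dr_output is True
    # Cross-entropy keeps a training signal even when soft_dr is tiny everywhere.
    loss = F.cross_entropy(logits, labels) - weight * log_soft_dr.mean()
    # soft_dr.sum() is a differentiable approximation of dr_output.sum().
    return loss, soft_dr.sum()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 5, requires_grad=True)
    labels = torch.randint(0, 5, (8,))
    loss, soft_count = soft_dr_loss(logits, labels)
    loss.backward()  # works: every step above is differentiable
    print(loss.item(), soft_count.item())
```

The `-log_soft_dr` term is effectively a per-class binary cross-entropy on the sigmoid scores, pushing the true class above the threshold and the other classes below it. In general, a smaller `temperature` tracks the hard count more closely but gives steeper, less stable gradients.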