Custom loss function that counts the number of true elements in a boolean array

These are some lines from my loss function, where output is the output of a multiclass classification network:

bin_count = torch.bincount(torch.where(output > .1)[0], minlength=output.shape[0])

dr_output = (bin_count == 1) & (torch.argmax(output, dim=1) == labels)
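A self-contained reproduction of these two lines makes the problem visible. The batch (random logits passed through a softmax) and the labels are made up for illustration, and the sketch assumes output holds class probabilities, which the 0.1 threshold suggests:

```python
import torch

torch.manual_seed(0)
# Hypothetical network output: 4 samples, 3 classes, softmax probabilities
logits = torch.randn(4, 3, requires_grad=True)
output = torch.softmax(logits, dim=1)
labels = torch.tensor([0, 1, 2, 1])  # hypothetical targets

# Per-sample number of classes whose probability exceeds 0.1
bin_count = torch.bincount(torch.where(output > .1)[0], minlength=output.shape[0])
# True where exactly one class passes the threshold and it is the predicted label
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1) == labels)

count = dr_output.sum()  # integer tensor, detached from the autograd graph
try:
    count.backward()
    backward_ok = True
except RuntimeError:
    # The comparisons and argmax carry no gradient, so count has no grad_fn
    backward_ok = False

print(dr_output.dtype, count.requires_grad, backward_ok)  # torch.bool False False
```

Because dr_output is a boolean tensor, its sum is an integer tensor that is not connected to the network's parameters, so calling backward() on it alone fails, and adding it to another loss term contributes no gradient.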

I want dr_output.sum() to be part of my loss function, but my implementation has several problems. Some of the operations I use (argmax, the comparisons and bincount) are not differentiable in PyTorch, and dr_output.sum() can also be zero, which is a problem if I use it as the only term of my loss. Can anyone suggest a way around these problems?

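You could use a soft argmax instead of argmax to get around the non-differentiability of argmax. The other non-differentiable operations you are using (the > threshold, the bincount and the == comparisons) can be replaced by soft analogues in the same way, for example a sigmoid in place of the hard threshold.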
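A minimal sketch of that idea, replacing each hard step of the question's snippet with a smooth counterpart. It assumes output holds class probabilities; soft_dr_count and its threshold, temperature, width and beta values are hypothetical and would need tuning. "Soft argmax" is taken here as a temperature-scaled softmax, which approximates the one-hot argmax (the expected-index variant is another common choice). Every factor is strictly positive, so the surrogate can still give a gradient in batches where the hard count is 0, although with sharp settings that gradient can become very small.

```python
import torch


def soft_dr_count(output: torch.Tensor, labels: torch.Tensor,
                  threshold: float = 0.1, temperature: float = 0.02,
                  width: float = 0.5, beta: float = 10.0) -> torch.Tensor:
    """Hypothetical differentiable stand-in for dr_output.sum()."""
    # Soft (output > threshold): close to 1 above the threshold, close to 0 below it
    above = torch.sigmoid((output - threshold) / temperature)
    # Soft bin_count: approximate number of classes above the threshold in each row
    soft_bin_count = above.sum(dim=1)
    # Soft (bin_count == 1): equals 1 when the count is exactly 1, decays away from it
    exactly_one = torch.exp(-((soft_bin_count - 1.0) / width) ** 2)
    # Soft argmax as a temperature-scaled softmax (smooth one-hot argmax);
    # its entry at the true label acts as a soft (argmax == labels)
    soft_onehot = torch.softmax(beta * output, dim=1)
    correct = soft_onehot.gather(1, labels.unsqueeze(1)).squeeze(1)
    # Soft logical AND: product of the two soft indicators, summed over the batch
    return (exactly_one * correct).sum()


output = torch.tensor([[0.90, 0.05, 0.05],
                       [0.50, 0.45, 0.05]], requires_grad=True)
labels = torch.tensor([0, 0])

soft_count = soft_dr_count(output, labels)  # approximates the hard count (1 here)
# Minimize the soft fraction of samples that fail the criterion
loss = 1.0 - soft_count / output.shape[0]
loss.backward()
print(soft_count.item(), output.grad is not None)
```

Smaller temperature and width and a larger beta bring the surrogate closer to the hard count but make its gradients sharper and less informative; a term like this may also work better added to a standard cross-entropy loss than used on its own.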