You can only call torch.nonzero() on a plain tensor, not a Variable. That makes sense: I doubt that counting non-zero elements would be differentiable.
But sum( abs( x / (abs(x) + epsilon) ) ) approximates the number of non-zero elements, and it is differentiable.
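A minimal sketch of that surrogate, written in plain Python for illustration (the function name is my own; in practice you would compute this on a torch tensor):

```python
def soft_nonzero_count(xs, eps=1e-8):
    # |x| / (|x| + eps) is ~1 for any x far from 0 and exactly 0 at x = 0,
    # so the sum approximates the number of non-zero entries -- and, unlike
    # torch.nonzero(), the expression is differentiable in x.
    return sum(abs(x) / (abs(x) + eps) for x in xs)

print(soft_nonzero_count([0.0, 3.0, -2.0, 0.0]))  # ~2.0
```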
Actually, I was trying to take the average of all non-zero elements in a 1-d tensor, which is the total loss for my model. I am doing the following:
loss = losses.sum() / torch.nonzero(losses.data).size(0)
It is working as expected and backpropagation is not causing any problem, so I am assuming that taking the average of the non-zero elements is differentiable. Do you have any thoughts about it?
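A pure-Python stand-in for that line (my own illustration; the real code runs on tensors) makes the computation explicit:

```python
def mean_of_nonzero(losses):
    # mirrors: loss = losses.sum() / torch.nonzero(losses.data).size(0)
    # The denominator counts the non-zero entries; because it is computed
    # from .data in the original, autograd treats it as a constant.
    n = sum(1 for x in losses if x != 0)
    return sum(losses) / n

print(mean_of_nonzero([0.0, 2.0, 4.0, 0.0]))  # 3.0
```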
Yes, of course your total loss L is (piecewise) differentiable. It can be more formally defined as:
L = sum( Li ) / sum( 1{Li ≠ 0} ),
where 1{c} is the indicator function (which is 1 when c is true and 0 otherwise).
Clearly, the function f ( Li ) = 1{Li ≠ 0} has derivative equal to 0 everywhere, except at Li = 0, where the derivative does not exist. In practice, you may assume that it is 0 everywhere.
Your biggest concern should be ensuring that you have no problem using the tensor losses.data instead of the Variable losses. This is because PyTorch will see torch.nonzero(losses.data).size(0) as a constant and not as a function of losses. Luckily, you may easily check that the derivative of L w.r.t. each of the losses Lj is the same whether you consider sum( 1{Li ≠ 0} ) as a function of Lj or not.
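That check can be sketched with a finite difference in plain Python (my own construction, standing in for the tensor code), away from any point where some Li crosses zero:

```python
def L(losses):
    # piecewise-constant denominator, as in the original loss
    n = sum(1 for x in losses if x != 0)
    return sum(losses) / n

losses = [0.5, 0.0, 1.5, 2.0]  # 3 non-zero entries
h = 1e-6
j = 2  # perturb the third loss, which is far from 0
plus, minus = losses[:], losses[:]
plus[j] += h
minus[j] -= h
grad = (L(plus) - L(minus)) / (2 * h)
# A small perturbation does not change the non-zero count, so the
# numerical gradient matches the "constant denominator" answer 1/3.
print(grad)  # ~0.3333
```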
@wasiahmad, what you are minimizing is just losses.sum(), and your gradient descent steps are multiplied by a weight that depends on the number of non-zero elements, which differs at each iteration. But nothing guarantees that this will minimize sum(x)/nonzero(x) for all x, which is (I think) what you want to do.
Noticed this too! It seems nonzero() is super slow, and its runtime also varies a lot between calls. We had a situation where the first call ran very fast and then subsequent calls ran 10x slower.
Have the same problem. For the same function with the exact same input, the first call takes 2.1 s and the second 0.002 s. I don't know exactly what causes this… I experience it in PyTorch 1.0.
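A common culprit for this first-call-slow pattern is one-time warm-up cost (allocator and kernel initialization), or timing GPU code without calling torch.cuda.synchronize() first so that queued asynchronous work is billed to whichever op you measure next. A framework-agnostic benchmarking sketch that warms up before timing (the helper name is my own):

```python
import time

def bench(fn, *args, warmup=3, iters=10):
    # Run a few untimed warm-up calls so one-time costs do not pollute
    # the measurement; then report the mean time over `iters` calls.
    # (For CUDA code you would also call torch.cuda.synchronize()
    # before reading the clock.)
    for _ in range(warmup):
        fn(*args)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) / iters

print(bench(lambda: sum(range(1000))))
```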