How to add the number of non-zero elements to the loss function

Hi everyone,

I'm trying to implement an idea I found in a paper where an encoder is trained with the loss MSE(y*, y) + (number of non-zero values in one of the layers).
That is, the loss combines the prediction MSE with a count of the non-zero values of the encoding layer.
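Written out, the objective is roughly the following. Here z is the encoding-layer output, and λ is a weight added only for illustration (λ = 1 gives the loss exactly as stated):

```latex
\mathcal{L}(y^*, y, z) = \mathrm{MSE}(y^*, y) + \lambda \, \lVert z \rVert_0,
\qquad
\lVert z \rVert_0 = \#\{\, i : z_i \neq 0 \,\} = \sum_i \mathbb{1}[z_i \neq 0]
```

Each indicator term is piecewise constant. Its derivative is therefore zero wherever z_i ≠ 0 and undefined at z_i = 0, so the count on its own gives gradient descent nothing to follow.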

My problem is that I can't implement the non-zero value counting. I have tried:

  1. Adding (encoded_x != 0.0).sum() to the loss function, where encoded_x holds the values of the middle layer. Apparently the comparison is not differentiable (the result carries no gradient), so it has no effect (?).
  2. Approximation: penalizing by the value itself, by adding encoded_x.sum() or even torch.pow(encoded_x, 1/8).sum(). The result is that the encoding layer ends up with really small values.
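The two observations above can be checked with a small sketch. The encoder output below is a made-up tensor, and recon is a hypothetical stand-in for the reconstruction term. The sketch compares gradients with and without the count term, and then looks at the gradient of an absolute-value penalty:

```python
import torch

def grad_of(loss_fn):
    """Return d(loss)/d(encoded_x) for a fixed, made-up encoder output."""
    encoded_x = torch.tensor([0.0, 0.5, -2.0, 0.0, 3.0], requires_grad=True)
    loss_fn(encoded_x).backward()
    return encoded_x.grad

def recon(z):
    # Hypothetical stand-in for the reconstruction/MSE term
    return (z ** 2).mean()

# 1. Adding the non-zero count: (z != 0) is a bool tensor with no grad_fn,
#    so the sum is a constant as far as autograd is concerned.
g_plain = grad_of(recon)
g_count = grad_of(lambda z: recon(z) + (z != 0.0).sum())
print(torch.equal(g_plain, g_count))  # True: the count changes nothing

# 2. Magnitude penalty: the gradient of |z| is sign(z), a constant push of
#    size 1 toward zero no matter how small the value already is.
g_l1 = grad_of(lambda z: z.abs().sum())
print(g_l1)
```

As a side note, torch.pow(encoded_x, 1/8) returns NaN for negative entries, and its derivative (1/8)·x^(-7/8) is infinite at 0. That variant therefore only behaves for a non-negative layer (e.g. after a ReLU), and even then exact zeros can produce inf gradients.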

Any ideas how this might be implemented? I understand the derivative of the count is not always defined, but there must be a way.



A penalty on the number of non-zero values tries to enforce sparsity. This is usually achieved with an L1 penalty. To add it, compute torch.nn.L1Loss()(encoded_x, torch.zeros_like(encoded_x)) and add it to your total loss. Note that nn.L1Loss has to be instantiated before it is called. Equivalently, you can use encoded_x.abs().mean().
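Below is a minimal sketch of adding the L1 term inside a training step. It assumes a toy autoencoder: the layer sizes, the weight l1_weight and the random data are all hypothetical and chosen only for illustration. The non-zero fraction is printed as a metric only and is not part of the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy autoencoder; sizes are hypothetical
encoder = nn.Sequential(nn.Linear(20, 8), nn.ReLU())
decoder = nn.Linear(8, 20)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)

l1_weight = 1e-3          # hypothetical trade-off between reconstruction and sparsity
x = torch.randn(64, 20)   # random stand-in data

for _ in range(100):
    encoded_x = encoder(x)
    reconstruction = decoder(encoded_x)
    mse = F.mse_loss(reconstruction, x)
    # L1 sparsity term; same value as nn.L1Loss()(encoded_x, torch.zeros_like(encoded_x))
    sparsity = F.l1_loss(encoded_x, torch.zeros_like(encoded_x))
    loss = mse + l1_weight * sparsity
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# The exact count is only monitored here, never backpropagated
nonzero_fraction = (encoded_x != 0).float().mean().item()
print(f"mse={mse.item():.4f}  nonzero fraction={nonzero_fraction:.2f}")
```
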

Hi, and thank you very much for your response. I forgot to mention this, but I have tried something similar (although I used torch.zeros rather than zeros_like). My impression is that L1 is not exactly what I need: it penalizes the magnitude of each value rather than applying a binary (zero/non-zero) penalty, so the result is simply small values. Did I get something wrong?
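A quick numerical check supports that intuition. The two codes below are made up. L1 ranks a dense code of small values as "cheaper" than a sparse code with one large value, while the exact count ranks them the other way round.

The sketch also includes one differentiable stand-in for the count. This is a swapped-in technique, not something proposed in this thread: each indicator is replaced by z_i²/(z_i² + eps), which is 0 at z_i = 0 and close to 1 once |z_i| is much larger than sqrt(eps), so the sum approaches the count as eps → 0. The name smooth_l0 and the value of eps are hypothetical.

```python
import torch

def nonzero_count(z: torch.Tensor) -> float:
    # Exact count of non-zero entries (not differentiable)
    return float((z != 0).sum())

def l1_penalty(z: torch.Tensor) -> torch.Tensor:
    # Magnitude-based penalty: sum of |z_i|
    return z.abs().sum()

def smooth_l0(z: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # Hypothetical smooth surrogate for the count: each term is 0 at z_i = 0
    # and close to 1 once |z_i| >> sqrt(eps); eps is an assumed value.
    z2 = z ** 2
    return (z2 / (z2 + eps)).sum()

sparse_code = torch.tensor([10.0, 0.0, 0.0])  # one large non-zero entry
dense_code = torch.tensor([0.1, 0.1, 0.1])    # three small non-zero entries

for name, code in [("sparse", sparse_code), ("dense", dense_code)]:
    print(name, nonzero_count(code), l1_penalty(code).item(), smooth_l0(code).item())
# L1 prefers the dense code (about 0.3 vs 10); the count and the surrogate
# prefer the sparse code (about 1 vs 3).

# The surrogate is differentiable: its gradient is 0 at exact zeros and
# shrinks for large |z_i|, so it mainly pushes small values toward zero.
z = torch.tensor([0.0, 0.05, 3.0], requires_grad=True)
smooth_l0(z).backward()
print(z.grad)
```

With a very small eps the gradient is concentrated on values near zero and vanishes elsewhere. With a larger eps the penalty behaves more like a scaled squared (L2-style) penalty, since z²/(z² + eps) ≈ z²/eps for small |z|. Both eps and the weight on this term would likely need tuning.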