Gradient flow through torch.count_nonzero

With the code below, the backward function is called after the forward pass runs. But once I uncomment the line “output=torch.count_nonzero(output)”, it raises “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”. I need to use torch.count_nonzero, and since it is non-differentiable I want to pass a grad_input of 1 back through it.

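from torch.autograd import Function
import torch
import torch.nn as nn

class custom_fun(Function):

    @staticmethod
    def forward(ctx, input):
        print("calling forward", input)
        ctx.save_for_backward(input)  # needed so backward can read ctx.saved_tensors
        output = input.clone()
        # output=torch.count_nonzero(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        print("Calling backward function ")
        # Pass a gradient of 1 for every input element (shape must match input)
        grad_input = torch.ones_like(input)
        return grad_input

class my_module(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)

    def forward(self, x):
        out = self.fc1(x)
        return custom_fun.apply(out)

mod = my_module()
out = mod(torch.randn(1, 1))
out.sum().backward()  # triggers custom_fun.backward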


The problem is that the output of count_nonzero is a Tensor of type int64, so it cannot require gradients (only floating-point and complex tensors can).
So you want to add output = output.float() after the count_nonzero call to make sure the output of your custom Function can require gradients.
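A minimal sketch of that fix, assuming the goal is the OP's: count non-zeros in forward and send a gradient of 1 per input element in backward. The class name CountNonzeroSurrogate is hypothetical, and unlike the OP's backward it multiplies by grad_output so the chain rule still holds when the count feeds into a larger loss.

```python
import torch
from torch.autograd import Function

class CountNonzeroSurrogate(Function):
    """Hypothetical name: exact count in forward, hand-chosen gradient in backward."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # count_nonzero returns int64; cast to float so the output can require grad
        return torch.count_nonzero(input).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Surrogate: treat d(count)/d(x_i) as 1 for every element
        return torch.ones_like(input) * grad_output

x = torch.tensor([0.0, 2.0, -1.5, 0.0], requires_grad=True)
n = CountNonzeroSurrogate.apply(x)
n.backward()
print(n.item())  # 2.0
print(x.grad)    # tensor([1., 1., 1., 1.])
```

Without the .float() cast, the int64 output is marked as non-differentiable and n.backward() fails with the error from the question.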

Thanks. Now gradient flows.

Is @cbd’s custom function a sound approach? As the OP seems to mention, torch.count_nonzero is non-differentiable. If I understand correctly, is this essentially ignoring the true gradient of torch.count_nonzero?
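One way to check this is to probe torch.count_nonzero with finite differences. This is an illustrative sketch; the test vector and eps are arbitrary choices. Nudging a non-zero component leaves the count unchanged, so the true derivative there is 0, while nudging an exact zero makes the count jump by 1, so no finite derivative exists at that point. A backward that returns 1 per element is therefore a hand-chosen surrogate, not the derivative of count_nonzero.

```python
import torch

# Exact zeros at positions 0 and 3, non-zeros at positions 1 and 2
x = torch.tensor([0.0, 2.0, -1.5, 0.0])
eps = 1e-4
base = torch.count_nonzero(x).item()

fd = []
for i in range(x.numel()):
    # Nudge one component and measure how much the count changes per unit step
    x_pert = x.clone()
    x_pert[i] += eps
    fd.append((torch.count_nonzero(x_pert).item() - base) / eps)

print(fd)  # 0 for the non-zero entries, a 1/eps jump for the zero entries
```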

Rather than use torch.count_nonzero, which is non-differentiable, an alternative is to approximate it via a “differentiable relaxation using a sum of narrow Gaussian basis functions”, i.e.:

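import torch

def narrow_gaussian(x, ell):
    # Gaussian bump of width ell: ~1 when x is near 0, ~0 when |x| >> ell
    return torch.exp(-0.5 * (x / ell) ** 2)

def approx_count_nonzero(x, ell=1e-3):
    # Approximation of || x ||_0 along the last dimension:
    # (number of components) - (soft count of near-zero components)
    return x.shape[-1] - narrow_gaussian(x, ell).sum(dim=-1)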
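Written out, with n the number of components and ℓ the width ell, the code computes the formula below. The gradient on the right is derived here from that formula for illustration; it is not part of the original post.

```latex
f(x_i) = \exp\!\left(-\frac{x_i^2}{2\ell^2}\right),
\qquad
\lVert x \rVert_0 \approx n - \sum_{i=1}^{n} f(x_i),
\qquad
\frac{\partial}{\partial x_i}\left(n - \sum_{j=1}^{n} f(x_j)\right) = \frac{x_i}{\ell^2}\, f(x_i)
```

The gradient is 0 at x_i = 0 exactly and is negligible once |x_i| is many multiples of ℓ, so in practice only components within a few ℓ of zero receive a useful gradient signal.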

As the i-th component (x[i], or x_i) approaches zero, the narrow Gaussian f evaluated at x_i tends towards 1:

$$\lim_{x_i \to 0} f(x_i) = 1$$

Likewise, as the i-th component moves away from zero in either direction, f(x_i) tends towards 0 (it is already negligible once |x_i| is a few multiples of ell):

$$\lim_{|x_i| \to \infty} f(x_i) = 0$$

So if three components have values close to zero and the rest do not, the sum of all the narrow Gaussians is approximately 3, and the number of non-zero components is approximately len(x) minus this sum.
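A quick numerical check of that argument, as a sketch: the helpers follow the snippet above (with len(x) written as x.shape[-1] so batched inputs also work), and the test vector is an arbitrary choice. Note that 1e-5 is far below ell = 1e-3, so the relaxation treats it as zero even though torch.count_nonzero counts it.

```python
import torch

def narrow_gaussian(x, ell):
    return torch.exp(-0.5 * (x / ell) ** 2)

def approx_count_nonzero(x, ell=1e-3):
    return x.shape[-1] - narrow_gaussian(x, ell).sum(dim=-1)

# Three components at or near zero, two clearly non-zero
x = torch.tensor([0.0, 1e-5, 0.0, 0.5, -2.0], requires_grad=True)

exact = torch.count_nonzero(x)   # counts 1e-5 as non-zero
approx = approx_count_nonzero(x) # roughly 5 - 3 = 2
print(exact.item(), approx.item())

# Unlike count_nonzero, the relaxation is differentiable
approx.backward()
print(x.grad)
```

The gradient is about 10 for the 1e-5 component and effectively 0 for the others, both at exact zeros and far from zero, which matches the derivative x_i / ell**2 * f(x_i) of the relaxation.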

Modified from Scipy non-linear inequality constraints.ipynb, as described in [FEATURE REQUEST] modify Ax API to allow for callable that evaluates a constraint and is passed to the optimizer · Issue #769 · facebook/Ax · GitHub.

Credit: David Eriksson, Senior Research Scientist at Meta