With the code below, the backward function is called when I run it. But once I uncomment the line “output=torch.count_nonzero(output)”, it fails with the error “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”. I need to use the “torch.count_nonzero” function, and since it is non-differentiable I want to pass a “grad_input” of 1 in backward.

from torch.autograd import Function
import torch
import torch.nn as nn

class custom_fun(Function):
    @staticmethod
    def forward(cxt, input):
        cxt.save_for_backward(input)
        output=torch.add(input,2)
        print("calling forward",output)
        # output=torch.count_nonzero(output)
        return output

    @staticmethod
    def backward(cxt, grad_output):
        input, = cxt.saved_tensors
        print("Calling backward function ")
        grad_input = grad_output.clone()
        grad_input = torch.ones(1)
        return grad_input

class my_module(nn.Module):
    def __init__(self):
        super(my_module,self).__init__()
        self.fc1 = nn.Linear(1,1)
        self.my_act_fn=custom_fun()

    def forward(self, x):
        out1=self.fc1(x)
        out=self.my_act_fn.apply(out1)
        return out

mod = my_module()
x=torch.tensor([1.1])
target=torch.tensor([2.5])
out_m=mod(x)
criterion=torch.nn.L1Loss()
loss=criterion(out_m,target)
loss.backward()

The problem is that the output of count_nonzero is a Tensor of dtype int64, so it cannot require gradients (only floating-point and complex tensors can).
So you want to add an output = output.float() after the count_nonzero call to make sure the output of your custom Function can require gradients.
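For reference, here is a minimal sketch of how that fix might look when applied to the code from the question. The names CountNonzeroFn and MyModule are made up for this example, the target is changed to a 0-dim tensor so that it matches the scalar count, and backward keeps the constant gradient of 1 that the question asks for, so treat it as an illustration rather than a recommended design.

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class CountNonzeroFn(Function):  # hypothetical name, not from the thread
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        shifted = torch.add(input, 2)
        # count_nonzero returns an int64 tensor; casting to float lets
        # autograd mark the output as requiring grad.
        return torch.count_nonzero(shifted).float()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Assumption carried over from the question: pass a constant
        # gradient of 1 with the input's shape. Note that this ignores
        # grad_output entirely.
        return torch.ones_like(input)


class MyModule(nn.Module):  # hypothetical name, mirrors my_module above
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)

    def forward(self, x):
        # Functions are used through .apply, without instantiating them.
        return CountNonzeroFn.apply(self.fc1(x))


mod = MyModule()
x = torch.tensor([1.1])
target = torch.tensor(2.5)  # 0-dim, to match the scalar count
out = mod(x)
loss = nn.L1Loss()(out, target)
loss.backward()  # no RuntimeError now: out is float and has a grad_fn
```

Because backward returns ones regardless of grad_output, the gradients that reach fc1 do not depend on the loss, which is worth keeping in mind before training with it.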

Is @cbd’s custom function a sound approach? As the OP mentions, torch.count_nonzero is non-differentiable. If I understand correctly, doesn’t this essentially ignore the gradient of torch.count_nonzero?

Rather than use torch.count_nonzero, which is non-differentiable, an alternative is to approximate it via a “differentiable relaxation using a sum of narrow Gaussian basis functions”, with one narrow Gaussian f, centred at zero, evaluated at each component of x.

As the value of the i-th component (x[i] or x_{i}) approaches zero, the Gaussian basis function evaluated at x_{i} tends towards 1:

lim_{x_{i}→0} f(x_{i}) = 1

Likewise, as the value of the i-th component moves away from zero, the Gaussian basis function evaluated at x_{i} tends towards 0:

lim_{|x_{i}|→∞} f(x_{i}) = 0

So if you have three components with values close to zero and the rest not close to zero, then the sum of all the narrow Gaussians is approximately 3, and the number of non-zero components is approximately len(x) minus this sum.
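A rough sketch of that relaxation in PyTorch is shown below. The thread does not give the exact form of the basis function, so this assumes f(x_{i}) = exp(-x_{i}^2 / (2 sigma^2)); the function name soft_count_nonzero and the width sigma = 0.01 are likewise choices made for this example only.

```python
import torch


def soft_count_nonzero(x: torch.Tensor, sigma: float = 1e-2) -> torch.Tensor:
    """Differentiable approximation of the number of non-zero entries in x.

    Assumed basis function: f(x_i) = exp(-x_i^2 / (2 * sigma^2)), which is
    close to 1 when x_i is near zero and close to 0 otherwise.
    """
    near_zero = torch.exp(-x.pow(2) / (2 * sigma ** 2))
    # The sum of the Gaussians estimates how many entries are (near) zero,
    # so the non-zero count is len(x) minus that sum.
    return x.numel() - near_zero.sum()


# Three components close to zero, three clearly away from zero.
x = torch.tensor([0.0, 1e-4, -2e-4, 1.5, -3.0, 0.7], requires_grad=True)
approx = soft_count_nonzero(x)  # roughly 3
approx.backward()               # gradients flow back to x
```

Note that this counts entries that are merely close to zero as zero, whereas torch.count_nonzero only treats exact zeros that way, so the two can disagree. The width sigma controls how close "close to zero" is: smaller values track the exact count more tightly but leave almost no gradient away from zero.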