Output shape of torch.count_nonzero

After running the code below, the output of torch.count_nonzero (the variable c1) has shape torch.Size([]).

1.) When “vb.backward(torch.ones(va.size()))” is called, it gives the following error, because c1’s shape is torch.Size([]):

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 1, 256, 256]) and output[0] has a shape of torch.Size([]).

2.) If I replace “c1 = torch.count_nonzero(input)” with “c1 = torch.tensor([1.1], dtype=torch.float)”, then it gives the following error:

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 1, 256, 256]) and output[0] has a shape of torch.Size([1]).
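Both errors have the same cause: the gradient tensor passed to backward must match the shape of the tensor backward is called on (the output), not the shape of the input. A minimal illustration with a built-in scalar-producing op:

```python
import torch

x = torch.rand(1, 1, 4, 4, requires_grad=True)
y = x.sum()  # scalar output, shape torch.Size([])
# y.backward(torch.ones(x.size()))  # would raise the same shape-mismatch error
y.backward()  # for a scalar output, no gradient argument is needed
print(x.grad.shape)  # torch.Size([1, 1, 4, 4])
```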

The code is as follows.

import torch
from torch.autograd import Function

class fun(Function):
    @staticmethod
    def forward(cxt, input):
        cxt.save_for_backward(input)
        c1 = torch.count_nonzero(input)
        print(c1.shape)
        return c1

    @staticmethod
    def backward(cxt, grad_output):
        input, = cxt.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

fun1 = fun.apply

if __name__ == "__main__":
    from torch.autograd import Variable
    a = torch.rand(1, 1, 256, 256)
    va = Variable(a, requires_grad=True)
    vb = fun1(va)
    vb.backward(torch.ones(va.size()))

One correction: the guard should read if __name__ == "__main__": (the forum stripped the double underscores).
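Since vb here is a 0-dim scalar, the simplest fix is to call vb.backward() with no gradient argument (or pass torch.ones_like(vb), which has the matching scalar shape) and broadcast grad_output inside backward. A minimal sketch of that fix, assuming a hypothetical CountNonzero name; the cast to float is added because on recent PyTorch versions the integer-dtype output of a custom Function is treated as non-differentiable:

```python
import torch
from torch.autograd import Function

class CountNonzero(Function):  # hypothetical name for this sketch
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # cast to float: integer-dtype outputs of a custom Function
        # are marked non-differentiable by autograd
        return torch.count_nonzero(input).to(torch.float)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # grad_output has the output's (scalar) shape; broadcast it
        # to the input's shape before applying the mask
        grad_input = grad_output * torch.ones_like(input)
        grad_input[input < 0] = 0
        return grad_input

va = torch.rand(1, 1, 256, 256, requires_grad=True)
vb = CountNonzero.apply(va)
print(vb.shape)       # torch.Size([])
vb.backward()         # scalar output: no gradient argument needed
print(va.grad.shape)  # torch.Size([1, 1, 256, 256])
```

Note that count_nonzero itself is not differentiable, so the backward above is a custom straight-through-style gradient, which appears to be the intent of the original Function.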