Could not compute gradient for out = (out > 0).float()

I want to do the following:

out[out > 0] = 1
out[out <= 0] = -1

So I used out = (out > 0).float() instead, but then the gradient cannot be computed. Why?

Code to reproduce:

import torch

w = torch.FloatTensor([1.0, 2.0])
w.requires_grad = True
out = w * 2
out = (out > 0).float() # the graph is cut here; the error itself is raised by backward() below

out.sum().backward()

print("print grad")
print("w has grad ", w.requires_grad)
print("w grad", w.grad)

The resulting error is:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

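This is intended behavior. Comparison operators such as > return a boolean tensor with no grad_fn, so autograd does not track them. A step function is not usefully differentiable (its derivative is 0 everywhere except at 0, where it is undefined), so the gradient history ends at the comparison and backward() has nothing to propagate through.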
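A quick way to see where the graph is cut is to inspect requires_grad and grad_fn after each step. This is a minimal check, assuming a recent PyTorch version in which comparisons return torch.bool tensors:

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
out = w * 2
print(out.grad_fn is not None)  # True: the multiplication is tracked

# The comparison produces a bool tensor that autograd does not track
mask = out > 0
print(mask.dtype, mask.requires_grad, mask.grad_fn)  # torch.bool False None

# Casting back to float does not reconnect it to the graph
step = mask.float()
print(step.requires_grad, step.grad_fn)  # False None
```

Because step does not require grad, step.sum().backward() raises the RuntimeError quoted above.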

However, if you do the assignment in place with boolean masks:

import torch

w = torch.FloatTensor([1.0, 2.0])
w.requires_grad_(True)
out = w * 2

out[out > 0] = 1
out[out <= 0] = -1

out.sum().backward()

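then backward() runs without error, because autograd records the masked assignments as in-place operations on out. Note, however, that the gradient is 0 wherever an element was overwritten by a constant; here both elements are overwritten, so w.grad is tensor([0., 0.]).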
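To check what the in-place version produces, here is a small sketch with one negative and one positive input (the input values are chosen only for illustration):

```python
import torch

w = torch.tensor([-1.0, 2.0], requires_grad=True)
out = w * 2  # values -2 and 4

# Masked in-place assignments; autograd records them, so out keeps a grad_fn
out[out > 0] = 1
out[out <= 0] = -1

out.sum().backward()
print(out.detach())  # values -1 and 1
print(w.grad)        # tensor([0., 0.]): every element was overwritten by a constant
```

The result has the +1/-1 values you asked for, but no gradient signal reaches w.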


Since your function's derivative is 0 everywhere (except at input 0, where it is not differentiable), you can also implement it explicitly as a custom torch.autograd.Function whose backward returns zeros:

import torch

class StepFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # Forward pass: 1.0 where input > 0, else 0.0
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # The step function is flat almost everywhere, so its gradient is 0
        return torch.zeros_like(grad_output)

w = torch.tensor([1.0, 2.0], requires_grad=True)
step_function = StepFunction.apply
out = w * 2
out = step_function(out)

out.sum().backward()
print(w.grad)  # tensor([0., 0.])
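StepFunction returns 0/1, whereas the question asks for +1/-1. A minimal sketch of the same idea producing +1/-1, keeping the zero backward (the class name SignStep is hypothetical). torch.sign is not a drop-in replacement here because it maps 0 to 0 rather than -1:

```python
import torch

class SignStep(torch.autograd.Function):
    """Hypothetical helper: +1 where input > 0, -1 where input <= 0."""

    @staticmethod
    def forward(ctx, input):
        ones = torch.ones_like(input)
        return torch.where(input > 0, ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        # Piecewise constant, so the gradient is 0 (same choice as StepFunction)
        return torch.zeros_like(grad_output)

w = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
out = SignStep.apply(w * 2)
out.sum().backward()
print(out.detach())  # values -1, -1, 1
print(w.grad)        # all zeros
```

Keep in mind that with a zero backward, nothing upstream of this op receives a gradient signal.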