Hi,

As far as I understand round function is not differentiable. So what is it doing in backward pass of torch.round()?

It’ll give you 0, as this snippet shows:

```
a = torch.tensor(1.5)
a.requires_grad_()
b = torch.round(a)
b.backward()
a.grad  # tensor(0.)
```

I tried this, and it looks like torch.round is not differentiable, but torch.clamp is. Is this clear from the documentation, or how do we know which functions are differentiable?

Both functions are differentiable almost everywhere. Think about what the functions look like.

round() is a step function, so its derivative is zero almost everywhere. Although it is differentiable (almost everywhere), it is not useful for learning because the gradient is zero.

clamp() is linear with slope 1 inside (min, max) and flat outside that range, so its derivative is 1 inside (min, max) and zero outside. It can be useful for learning as long as enough of the input falls inside the range.
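A quick check of both derivatives (a sketch; the input values are chosen just for illustration):

```
import torch

# round(): zero gradient almost everywhere
a = torch.tensor([0.4, 1.6], requires_grad=True)
torch.round(a).sum().backward()
print(a.grad)  # tensor([0., 0.])

# clamp(): gradient 1 inside (min, max), 0 outside
x = torch.tensor([-2.0, 0.5, 2.0], requires_grad=True)
torch.clamp(x, min=-1.0, max=1.0).sum().backward()
print(x.grad)  # tensor([0., 1., 0.])
```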


Hi,

I was just wondering if you know a way to round the output of a network while still allowing for successful learning?

Thanks

One possible approach is to implement your own round function with a customized backward pass (often called a straight-through estimator).

E.g.

```
import torch
from torch.autograd.function import InplaceFunction

class my_round_func(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        ctx.input = input
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # pass the incoming gradient through unchanged
        grad_input = grad_output.clone()
        return grad_input
```


How do I use it? Should I replace `InplaceFunction` with `nn.Module`?

It should be `torch.autograd.function.InplaceFunction`. You don't wrap it in an `nn.Module`; custom autograd Functions are called through their static `apply` method.
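A minimal usage sketch (the class is the same straight-through round as above, repeated here so the snippet is self-contained):

```
import torch
from torch.autograd.function import InplaceFunction

# round in forward, identity gradient in backward
class my_round_func(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

# custom Functions are invoked via .apply, not instantiated like a module
x = torch.tensor([0.4, 1.6], requires_grad=True)
y = my_round_func.apply(x)   # tensor([0., 2.])
y.sum().backward()
print(x.grad)                # tensor([1., 1.]) — gradient passes straight through
```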