Hi,
As far as I understand, the round function is not differentiable. So what does the backward pass of torch.round() actually do?
It’ll give you 0, as this snippet shows:
import torch

a = torch.tensor(1.5)
a.requires_grad_()
b = torch.round(a)
b.backward()
a.grad  # tensor(0.)
I tried this; it looks like torch.round is not differentiable, but torch.clamp is. Is this clear from the documentation, or how do we know which functions are differentiable?
Both functions are differentiable almost everywhere. Think about what the functions look like.
round() is a step function so it has derivative zero almost everywhere. Although it’s differentiable (almost everywhere), it’s not useful for learning because of the zero gradient.
clamp() is linear, with slope 1, inside (min, max) and flat outside that range. This means the derivative is 1 inside (min, max) and zero outside. It can be useful for learning as long as enough of the input falls inside the range.
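The two gradient behaviors described above can be checked directly; a minimal sketch:

```python
import torch

# round(): a step function, so the gradient is zero almost everywhere
a = torch.tensor(2.3, requires_grad=True)
torch.round(a).backward()
# a.grad is 0.0

# clamp(): slope 1 inside (min, max), flat outside
x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
torch.clamp(x, min=-1.0, max=1.0).sum().backward()
# x.grad is [0., 1., 0.] -- only the element inside the range gets gradient
```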
Hi,
I was just wondering if you know a way to round the output of a network while still allowing successful learning?
Thanks
One possible approach is to implement your own round function with customized backpropagation.
E.g.
from torch.autograd.function import InplaceFunction

class my_round_func(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        ctx.input = input
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input
How do I use it? Should I replace InplaceFunction with nn.Module?
Should be torch.autograd.function.InplaceFunction
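To answer the usage question: a custom Function is not used like a module; you call its .apply attribute. A runnable sketch along the lines of the snippet above (the class name is illustrative):

```python
import torch
from torch.autograd.function import InplaceFunction

class MyRound(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: pass the incoming gradient along unchanged
        return grad_output.clone()

# Call via .apply, not by instantiating the class
x = torch.tensor([0.4, 1.6], requires_grad=True)
y = MyRound.apply(x)
y.sum().backward()
# y is [0., 2.] and x.grad is [1., 1.]
```

Unlike the built-in torch.round, the gradient here is 1 everywhere, so learning can continue through the rounding step.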
If you need a gradient to be passed, you can use a straight-through estimator:
o = x + x.round().detach() - x.detach()
During the forward pass, x and -x.detach() cancel numerically, so o equals x.round().
During the backward pass, x.round().detach() and -x.detach() are ignored because they are detached from the graph, so the gradient flows through x instead of x.round().