Torch.round() gradient

Hi,

As far as I understand, the round function is not differentiable. So what does the backward pass of torch.round() actually do?

It’ll give you 0, as this snippet shows:

import torch

a = torch.tensor(1.5)
a.requires_grad_()
b = torch.round(a)
b.backward()
a.grad  # tensor(0.)

I tried this; it looks like torch.round is not differentiable, but torch.clamp is. Is this stated in the documentation, or how do we know which functions are differentiable?

Both functions are differentiable almost everywhere. Think about what the functions look like.

round() is a step function, so its derivative is zero almost everywhere. Although it is differentiable (almost everywhere), it isn’t useful for learning because the gradient is zero.

clamp() is linear, with slope 1, inside (min, max) and flat outside of the range. This means the derivative is 1 inside (min, max) and zero outside. It can be useful for learning as long as enough of the input is inside the range.
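
For example, here is a minimal check of the clamp behavior (the bounds and input values are just illustrative):

import torch

# One element below min, one inside (min, max), one above max
x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = torch.clamp(x, min=-1.0, max=1.0)
y.sum().backward()
print(x.grad)  # tensor([0., 1., 0.]) -- gradient is 1 only where the input lies inside (min, max)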


Hi,

I was just wondering if you know of a way to round the output of a network and still allow for successful learning?

Thanks