Is the Smooth-L1 loss equation differentiable?

The equation for Smooth-L1 loss is stated as:
    smooth_l1(x) = 0.5 * x^2 / beta    if |x| < beta
                   |x| - 0.5 * beta    otherwise

where x = pred - target.

To implement this equation in PyTorch, we need to use torch.where(), which is non-differentiable:

    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
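
As a sanity check (just a sketch, assuming a PyTorch version where torch.nn.functional.smooth_l1_loss accepts a beta argument), the manual torch.where() version can be compared against the built-in element-wise loss:

    import torch
    import torch.nn.functional as F

    pred = torch.randn(8, requires_grad=True)
    target = torch.randn(8)
    beta = 0.5

    # Manual piecewise version using torch.where()
    diff = torch.abs(pred - target)
    manual = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)

    # Built-in element-wise version (no reduction)
    builtin = F.smooth_l1_loss(pred, target, reduction='none', beta=beta)

    print(torch.allclose(manual, builtin))  # should print True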

Why do we use torch.where() for Smooth-L1 loss if it is non-differentiable?

Hi,

you are right that the comparison diff < beta inside torch.where() is not itself differentiable, so no gradient flows through the condition, only through the values of whichever branch gets selected. However, there are other ways to express the same equation without torch.where().

Here is just one example of how it might be done. The actual code is not open source, so I have no idea how they implement it, but this should show that there are other ways.

    import torch

    target = torch.zeros(16)
    pred = torch.arange(0, 2, step=0.125, requires_grad=True)
    beta = 1

    diff = torch.abs(pred - target)

    # Both branches of the piecewise definition, computed for every element
    higher = diff - 0.5 * beta         # linear branch, used where diff >= beta
    lower  = 0.5 * diff * diff / beta  # quadratic branch, used where diff <  beta

    # Select the branch element-wise with boolean masks instead of torch.where()
    loss = torch.empty_like(higher)
    loss[diff >= beta] = higher[diff >= beta]
    loss[diff <  beta] = lower[diff <  beta]

    print(loss)
    # Output
    tensor([0.0000, 0.0078, 0.0312, 0.0703, 0.1250, 0.1953, 0.2812, 0.3828, 0.5000,
            0.6250, 0.7500, 0.8750, 1.0000, 1.1250, 1.2500, 1.3750],
           grad_fn=<IndexPutBackward0>)
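
As a quick follow-up (a minimal sketch continuing the snippet above and reusing its pred, loss, and beta), you can check that gradients still flow through the boolean-mask version by backpropagating through the summed loss:

    # Continuing from the snippet above
    loss.sum().backward()
    print(pred.grad)
    # On the quadratic branch the gradient is (pred - target) / beta,
    # on the linear branch it is sign(pred - target), which is 1 here since pred >= target.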

Hope this helps :smile:
