What's up with the gradient of torch.linalg.norm?

I’d expect the gradient of the L2 norm of a vector of ones to be 2. The gradient is as I expect when I roll my own norm function (l2_norm in mwe below). The gradient is not what I expect when I call torch.linalg.norm. Is there some implementation detail about torch.linalg.norm that I need to know in order to understand what it thinks the gradient should be?

expected grad of l-2 norm: tensor([2., 2., 2.])
grad of l-1 norm: tensor([3., 3., 3.])
grad of l-2 norm: tensor([0.5774, 0.5774, 0.5774])
grad of l-3 norm: tensor([0.4807, 0.4807, 0.4807])

import torch

x = torch.ones(3, requires_grad=True)

def l2_norm(x):
    """Derivative of L2 norm is 2x."""
    return torch.sum((x * x.detach()) ** 2)

l2_norm(x).backward()
print(f"expected grad of l-2 norm: {x.grad}")
x.grad.zero_()

for p in [1, 2, 3]:
    torch.linalg.vector_norm(x, ord=p).backward()
    print(f"grad of l-{p} norm: {x.grad}")
    x.grad.zero_()

Your definition of the L2 norm is incorrect, which is why you're getting a different result from PyTorch. Here's the correct definition:

import torch

x = torch.ones(3, requires_grad=True)

def l2_norm(x):
    """Derivative of L2 norm is 2x."""
    #return torch.sum((x * x.detach()) ** 2) #wrong
    return x.pow(2).sum(-1).sqrt()           #correct

l2_norm(x).backward()
print(f"expected grad of l-2 norm: {x.grad}")
x.grad.zero_()

for p in [1, 2, 3]:
    torch.linalg.vector_norm(x, ord=p).backward()
    print(f"grad of l-{p} norm: {x.grad}")
    x.grad.zero_()

This returns

expected grad of l-2 norm: tensor([0.5774, 0.5774, 0.5774])
grad of l-1 norm: tensor([1., 1., 1.])
grad of l-2 norm: tensor([0.5774, 0.5774, 0.5774])
grad of l-3 norm: tensor([0.4807, 0.4807, 0.4807])

as expected! :slight_smile:

TL;DR - you forgot the sqrt on your L2-norm definition. Also, there’s no need to detach the 2nd x either!
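To spell out the detach point: treating `c = x.detach()` as a constant is exactly what made your original function report a gradient of 2. A minimal sketch (same shapes as your MWE):

```python
import torch

# With c = x.detach() held constant, the function is sum((x * c)**2),
# so d/dx = 2 * x * c**2, which is 2 at x = ones.
x = torch.ones(3, requires_grad=True)
torch.sum((x * x.detach()) ** 2).backward()
print(x.grad)  # tensor([2., 2., 2.])

# Without the detach, the same expression is sum(x**4),
# whose gradient is 4 * x**3, i.e. 4 at x = ones.
y = torch.ones(3, requires_grad=True)
torch.sum((y * y) ** 2).backward()
print(y.grad)  # tensor([4., 4., 4.])
```

So the detach wasn't just unnecessary; it silently changed which function autograd was differentiating.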

Thanks for finding that bug. I still have a question: why is torch's gradient different from 2x? I'm assuming that e.g. this applies.

That’s the gradient of the square of the norm, not the gradient of the norm itself. So they’re different functions you’re differentiating, and hence have different answers.

Yup, just figured that out. I wasn’t looking at the definition closely enough. Thanks!
