# What's up with the gradient of torch.linalg.norm?

I’d expect the gradient of the L2 norm of a vector of ones to be 2. The gradient is what I expect when I roll my own norm function (`l2_norm` in the MWE below), but not when I call `torch.linalg.norm`. Is there some implementation detail about `torch.linalg.norm` that I need to know in order to understand what it thinks the gradient should be?

```
expected grad of l-2 norm: tensor([2., 2., 2.])
grad of l-1 norm: tensor([3., 3., 3.])
grad of l-2 norm: tensor([0.5774, 0.5774, 0.5774])
grad of l-3 norm: tensor([0.4807, 0.4807, 0.4807])
```

```
import torch

def l2_norm(x):
    """Derivative of L2 norm is 2x."""
    return x.pow(2).sum(-1)

x = torch.ones(3, requires_grad=True)
l2_norm(x).backward()
print("expected grad of l-2 norm:", x.grad)

for p in [1, 2, 3]:
    x.grad = None  # clear the gradient accumulated by the previous backward()
    torch.linalg.vector_norm(x, ord=p).backward()
    print(f"grad of l-{p} norm:", x.grad)
```

Your definition of the L2 norm is incorrect, which is why you’re getting a different result from PyTorch. Here’s the correct definition:

```
import torch

def l2_norm(x):
    """Derivative of L2 norm is 2x."""
    return x.pow(2).sum(-1).sqrt()  # correct: the L2 norm needs the sqrt

x = torch.ones(3, requires_grad=True)
l2_norm(x).backward()
print("expected grad of l-2 norm:", x.grad)

for p in [1, 2, 3]:
    x.grad = None  # clear the gradient accumulated by the previous backward()
    torch.linalg.vector_norm(x, ord=p).backward()
    print(f"grad of l-{p} norm:", x.grad)
```

This returns

```
expected grad of l-2 norm: tensor([0.5774, 0.5774, 0.5774])
grad of l-1 norm: tensor([1., 1., 1.])
grad of l-2 norm: tensor([0.5774, 0.5774, 0.5774])
grad of l-3 norm: tensor([0.4807, 0.4807, 0.4807])
```

as expected!

TL;DR - you forgot the sqrt in your L2-norm definition. Also, there’s no need to detach the 2nd `x`!

Thanks for finding that bug. I still have a question: why is torch’s gradient different from `2x`? I’m assuming that e.g. this applies.

That’s the gradient of the square of the norm, not the gradient of the norm itself. You’re differentiating two different functions, so they have different gradients.
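To make the difference concrete, here is a small check (assuming the same three-element vector of ones as in the MWE above). Differentiating the norm itself gives x / ‖x‖₂, which is 1/√3 ≈ 0.5774 for a vector of ones, while differentiating the squared norm gives 2x:

```
import torch

# Gradient of the norm itself: d/dx ||x||_2 = x / ||x||_2
x = torch.ones(3, requires_grad=True)
torch.linalg.vector_norm(x, ord=2).backward()
print(x.grad)  # tensor([0.5774, 0.5774, 0.5774]) = 1 / sqrt(3)

# Gradient of the squared norm: d/dx ||x||_2^2 = 2x
y = torch.ones(3, requires_grad=True)
torch.linalg.vector_norm(y, ord=2).pow(2).backward()
print(y.grad)  # tensor([2., 2., 2.])
```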

Yup, just figured that out. I wasn’t looking at the definition closely enough. Thanks!
