# torch.norm with dim=(1, 2) gives NaN grads

I am getting some weird behavior when using `torch.norm` with `dim=(1, 2)` in my loss computation:

```python
import torch
from torch import nn

m = nn.Linear(3, 9)
nn.init.constant_(m.weight, 0)
nn.init.eye_(m.bias.view(3, 3))

x = torch.rand((2, 3))
out = m(x).view((2, 3, 3))

I = torch.eye(3).unsqueeze(0)
diff = out - I
loss = torch.mean(torch.norm(diff, dim=(1, 2)))
loss.backward()
```

After `backward()`, the bias gradient `m.bias.grad` is

```
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan])
```

However, if I instead use the equivalent nested form

```python
loss = torch.mean(torch.norm(torch.norm(diff, dim=1), dim=1))
```

it gives the correct result:

```
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.])
```

What could be the reason for this?

I think this is because the default norm (Frobenius) involves a square root, whose backward pass misbehaves when the input is exactly (or very nearly) zero.
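To see why, here is a quick sketch of the gradient of the Frobenius norm:

$$
\frac{\partial \lVert X \rVert_F}{\partial X_{ij}}
= \frac{\partial}{\partial X_{ij}} \sqrt{\sum_{k,l} X_{kl}^2}
= \frac{X_{ij}}{\lVert X \rVert_F},
$$

which is 0/0 at $X = 0$. Autograd hits the same problem stepwise: the backward of `sqrt` is $1 / (2\sqrt{y})$, which is `inf` at $y = 0$, and multiplying that by the inner gradient $2 X_{ij} = 0$ gives `inf * 0 = nan`. The following snippets demonstrate this: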

```
$ cat example.py
import torch

b = torch.zeros((2, 2, 2), requires_grad=True)
out = torch.mean(torch.sqrt(b ** 2))
out.backward()
print(b.grad)
$ python3 example.py
tensor([[[nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan]]])
```
```
$ cat example2.py
import torch

b = torch.zeros((2, 2, 2), requires_grad=True)
out = torch.mean(torch.sqrt(b ** 2 + 1e-12))
out.backward()
print(b.grad)
$ python3 example2.py
tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]])
```
Applying the same epsilon trick to the original code:

```
$ cat example0.py
import torch
from torch import nn

m = nn.Linear(3, 9)
nn.init.constant_(m.weight, 0)
nn.init.eye_(m.bias.view(3, 3))

x = torch.rand((2, 3))
out = m(x).view((2, 3, 3))

I = torch.eye(3).unsqueeze(0)
diff = out - I
loss = torch.mean(torch.norm(diff + 1e-12, dim=(1, 2)))
loss.backward()
```
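
Note that adding the epsilon to `diff` perturbs the values being normed. A minimal alternative sketch (the helper name and the `1e-12` value are my own choices, not a torch API) keeps the epsilon inside the square root instead, so the gradient stays finite at zero without changing what is measured:

```python
import torch

def frobenius_norm_safe(x, dim, eps=1e-12):
    # sqrt(sum(x^2) + eps): eps keeps the sqrt away from 0, so the
    # backward pass, x / sqrt(sum(x^2) + eps), is finite even at x == 0
    return torch.sqrt(torch.sum(x ** 2, dim=dim) + eps)

# drop-in replacement for torch.norm(diff, dim=(1, 2)) above:
# loss = torch.mean(frobenius_norm_safe(diff, dim=(1, 2)))
```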