# torch.norm with dim=(1, 2) gives NaN grads

I am getting some weird behavior when using `torch.norm` with `dim=(1, 2)` in my loss computation:

```python
import torch
from torch import nn

m = nn.Linear(3, 9)
nn.init.constant_(m.weight, 0)
nn.init.eye_(m.bias.view(3, 3))

x = torch.rand((2, 3))
out = m(x).view((2, 3, 3))

I = torch.eye(3).unsqueeze(0)
diff = out - I
loss = torch.mean(torch.norm(diff, dim=(1, 2)))
loss.backward()
```

After `backward()`, the bias gradient `m.bias.grad` is

```
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan])
```

However, if I instead use the equivalent nested form

```python
loss = torch.mean(torch.norm(torch.norm(diff, dim=1), dim=1))
```

it gives the correct result:

```
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.])
```

What could be the reason for this?

I think this is because the default norm (Frobenius) involves a square root, whose backward pass misbehaves when the input is exactly (or very nearly) zero.
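To see why, here is a quick sketch of the gradient of the Frobenius norm:

$$
\frac{\partial \lVert X \rVert_F}{\partial X_{ij}}
= \frac{\partial}{\partial X_{ij}} \sqrt{\sum_{k,l} X_{kl}^2}
= \frac{X_{ij}}{\lVert X \rVert_F},
$$

which is 0/0 at $X = 0$. Autograd hits the same problem stepwise: the backward of `sqrt` is $1 / (2\sqrt{y})$, which is `inf` at $y = 0$, and multiplying that by the inner gradient $2 X_{ij} = 0$ gives `inf * 0 = nan`. The following snippets demonstrate this: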

```
$ cat example.py
import torch

b = torch.zeros((2, 2, 2), requires_grad=True)
out = torch.mean(torch.sqrt(b ** 2))
out.backward()
print(b.grad)
$ python3 example.py
tensor([[[nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan]]])
```
```
$ cat example2.py
import torch

b = torch.zeros((2, 2, 2), requires_grad=True)
out = torch.mean(torch.sqrt(b ** 2 + 1e-12))
out.backward()
print(b.grad)
$ python3 example2.py
tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]])
```
Applying the same epsilon trick to the original code:

```
$ cat example0.py
import torch
from torch import nn

m = nn.Linear(3, 9)
nn.init.constant_(m.weight, 0)
nn.init.eye_(m.bias.view(3, 3))

x = torch.rand((2, 3))
out = m(x).view((2, 3, 3))

I = torch.eye(3).unsqueeze(0)
diff = out - I
loss = torch.mean(torch.norm(diff + 1e-12, dim=(1, 2)))
loss.backward()
```
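
Note that adding the epsilon to `diff` perturbs the values being normed. A minimal alternative sketch (the helper name and the `1e-12` value are my own choices, not a torch API) keeps the epsilon inside the square root instead, so the gradient stays finite at zero without changing what is measured:

```python
import torch

def frobenius_norm_safe(x, dim, eps=1e-12):
    # sqrt(sum(x^2) + eps): eps keeps the sqrt away from 0, so the
    # backward pass, x / sqrt(sum(x^2) + eps), is finite even at x == 0
    return torch.sqrt(torch.sum(x ** 2, dim=dim) + eps)

# drop-in replacement for torch.norm(diff, dim=(1, 2)) above:
# loss = torch.mean(frobenius_norm_safe(diff, dim=(1, 2)))
```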