I am getting some weird behavior when using torch.norm with dim=(1, 2) in my loss computation:
import torch
import torch.nn as nn

m = nn.Linear(3, 9)
nn.init.constant_(m.weight, 0)   # zero out the weights
nn.init.eye_(m.bias.view(3, 3))  # set the bias to a flattened 3x3 identity
x = torch.rand((2, 3))
out = m(x).view((2, 3, 3))       # with zero weights, out is the identity for each sample
I = torch.eye(3).unsqueeze(0)
diff = out - I                   # exactly zero everywhere
loss = torch.mean(torch.norm(diff, dim=(1, 2)))  # Frobenius norm per sample
loss.backward()
print(m.bias.grad)
gives tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan]).
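I suspect it matters that, with this initialization, diff is exactly zero for every sample (zero weights, identity bias), so the norm being differentiated is 0, where the gradient of the Frobenius norm, X / ||X||_F, is 0/0. A quick check, continuing from the snippet above:

print(torch.allclose(diff, torch.zeros_like(diff)))  # True: diff is exactly zero
print(torch.norm(diff, dim=(1, 2)))                  # both per-sample norms are 0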
However, if I instead use the equivalent nested form
loss = torch.mean(torch.norm(torch.norm(diff, dim=1), dim=1))
it gives the correct result: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]).
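If it helps narrow things down, I believe the same thing happens without nn.Linear at all; a minimal sketch using a plain zero tensor in place of diff:

z = torch.zeros(2, 3, 3, requires_grad=True)
torch.mean(torch.norm(z, dim=(1, 2))).backward()
print(z.grad)  # I would expect nan here as well, matching the case above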
What could be the reason for this difference?