detect_anomaly yields "RuntimeError: Function 'MseLossBackward' returned nan values in its 0th output.", which, as I mentioned in my first post, isn’t very helpful in this case since the NaNs are already present in the input to the loss function.
There’s usually exactly one NaN in the first batch. Interestingly, the exact index at which the NaN occurs in the batch (or whether it occurs at all) seems to vary despite my use of torch.manual_seed. Following your advice, I wrote the following hook, which raises a RuntimeError if a NaN is encountered, and registered it on my model’s submodules:
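import torch

def nan_hook(self, inp, output):
    # Normalize to a sequence so single-tensor and tuple outputs are handled alike.
    if not isinstance(output, tuple):
        outputs = [output]
    else:
        outputs = output
    for i, out in enumerate(outputs):
        nan_mask = torch.isnan(out)
        if nan_mask.any():
            print("In", self.__class__.__name__)
            # First-dimension indices of the entries that contain at least one NaN.
            rows = nan_mask.nonzero()[:, 0].unique(sorted=True)
            raise RuntimeError(
                f"Found NaN in output {i} at indices: {nan_mask.nonzero()} "
                f"where: {out[rows]}"
            )

# model.modules() yields the model itself and every nested submodule.
for submodule in model.modules():
    submodule.register_forward_hook(nan_hook)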
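To show how the hook behaves, here is a minimal self-contained sketch on a toy model. The SqrtLayer module and the hand-set weights are hypothetical and exist only to force a NaN; the hook is a trimmed-down variant of the one above, not my actual model code.

```python
import torch
from torch import nn


def nan_hook(module, inputs, output):
    # Accept both single-tensor and tuple outputs.
    outputs = output if isinstance(output, tuple) else (output,)
    for i, out in enumerate(outputs):
        if torch.is_tensor(out) and torch.isnan(out).any():
            raise RuntimeError(
                f"Found NaN in output {i} of {module.__class__.__name__}"
            )


class SqrtLayer(nn.Module):
    # Hypothetical layer: sqrt of a negative input yields NaN.
    def forward(self, x):
        return torch.sqrt(x)


model = nn.Sequential(nn.Linear(2, 2), SqrtLayer())
with torch.no_grad():
    # Force negative pre-activations so SqrtLayer produces NaN.
    model[0].weight.fill_(-1.0)
    model[0].bias.zero_()

for submodule in model.modules():
    submodule.register_forward_hook(nan_hook)

try:
    model(torch.ones(1, 2))
    caught = None
except RuntimeError as err:
    caught = str(err)

print(caught)
```

Because forward hooks fire as each submodule finishes, the error names the innermost module that first emits a NaN rather than the enclosing container.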
From what I can tell the NaNs are generated from a call to torch.nn.functional.gumbel_softmax. I tried to check for the offending input as follows:
import torch.nn.functional as functional

gumbel_one_hot = functional.gumbel_softmax(raw_attention, hard=True, dim=1)
nan_mask = torch.isnan(gumbel_one_hot)
if nan_mask.any():
    print(nan_mask.nonzero())
    # Batch rows that contain at least one NaN in the output.
    indices = nan_mask.nonzero()[:, 0].unique(sorted=True)
    print("Input:", raw_attention[indices])
    print("Output:", gumbel_one_hot[indices])
    raise RuntimeError("NaN encountered in gumbel softmax")
However, it prints the following. This input shouldn’t produce NaNs, and I can’t reproduce the result in a separate script either, even with the same seed on the GPU:
Input: tensor([[-0.2471, -0.2118, -0.1458, -0.1783, -0.1228, -0.1062, -0.0502, 0.0112,
0.0097, -0.2189, -inf, -inf, 0.2214, 0.2723, 0.0808, 0.1290,
0.1007, 0.4129, 0.4808, 0.3924, 0.2889, 0.4916, -inf, -inf]],
device='cuda:0', grad_fn=<IndexBackward>)
Output: tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
device='cuda:0', grad_fn=<IndexBackward>)
Is this issue possibly related to https://github.com/pytorch/pytorch/issues/22442?
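One mechanism that would be consistent with an entire row turning into NaN from a harmless-looking input is a single infinite sample in the Gumbel noise. I have not verified that this is what happens in my case. The sketch below assumes the standard formulation, softmax((logits + noise) / tau), and injects the noise by hand; gumbel_softmax_with_noise is a hypothetical helper, not a PyTorch function, and the noise values are made up.

```python
import torch


def gumbel_softmax_with_noise(logits, noise, tau=1.0, dim=-1):
    # Hypothetical helper: the usual Gumbel-softmax relaxation, but with
    # the noise passed in explicitly so its effect can be inspected.
    return torch.softmax((logits + noise) / tau, dim=dim)


# A few values taken from the printed input, including one -inf entry.
logits = torch.tensor([[-0.2471, 0.0112, float("-inf"), 0.4916]])

# Made-up noise samples: one fully finite, one with a single +inf entry.
finite_noise = torch.tensor([[0.3, -0.1, 1.2, 0.5]])
inf_noise = torch.tensor([[0.3, float("inf"), 1.2, 0.5]])

ok = gumbel_softmax_with_noise(logits, finite_noise)
bad = gumbel_softmax_with_noise(logits, inf_noise)

print("finite noise:", ok)
print("one inf in noise:", bad)
```

With finite noise, the -inf logits simply receive probability zero. With a single +inf in the noise, the softmax has to evaluate inf - inf when it subtracts the row maximum, and the resulting NaN spreads to the whole row through the normalizing sum. If the noise sampling only rarely produces such a value, that would also fit the NaN appearing at varying positions.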