MSE Loss becomes NaN after a few iterations

I am training a deep model with an LSTM and GNNs. It is a regression problem, so the loss is MSE loss. The loss becomes NaN after a few iterations. I turned on torch.autograd.set_detect_anomaly(True) to trace where the NaN first appears in the backward pass.
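I enable anomaly detection once, before the training loop starts. A minimal self-contained sketch of the setup (toy tensors, not my actual model):

    import torch

    torch.autograd.set_detect_anomaly(True)  # enabled once, before training starts

    # Toy example of what anomaly mode catches: 0/0 produces NaN in the forward
    # pass, and backward() then raises an error naming the backward op that
    # produced the NaN gradient.
    x = torch.zeros(1, requires_grad=True)
    y = (x / x).sum()
    y.backward()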

The code in question, inside the model's forward pass, is:


    # pack_padded_sequence / pad_packed_sequence are imported from torch.nn.utils.rnn
    # sort the batch by length (descending), and remember how to undo the sort
    order = torch.argsort(lengths, descending=True)
    order_rev = torch.argsort(order)
    x, lengths = x[order], lengths[order]

    # feed-forward through the bidirectional RNN
    self.rnn.flatten_parameters()
    total_length = x.shape[1]
    h = self.embed(x)
    h = pack_padded_sequence(h, lengths, batch_first=True)
    h, _ = self.rnn(h)
    # unpack, drop the first and last time steps, and restore the original batch order
    h = pad_packed_sequence(h, batch_first=True, total_length=total_length)[0][:, 1:-1][order_rev]
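For context, the pack/unpack path behaves as I expect when exercised in isolation; here is a standalone sketch with a plain nn.LSTM and made-up sizes (not my actual module):

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    rnn = nn.LSTM(input_size=5, hidden_size=7, batch_first=True, bidirectional=True)

    x = torch.randn(3, 10, 5)           # batch of 3 sequences, max length 10
    lengths = torch.tensor([4, 10, 6])  # unsorted true lengths

    # sort, pack, run the RNN, unpack, then undo the sort
    order = torch.argsort(lengths, descending=True)
    order_rev = torch.argsort(order)
    packed = pack_padded_sequence(x[order], lengths[order], batch_first=True)
    out, _ = rnn(packed)
    out = pad_packed_sequence(out, batch_first=True, total_length=10)[0]
    out = out[:, 1:-1][order_rev]       # drop first/last steps, restore original order

    print(out.shape)  # torch.Size([3, 8, 14]) -- 2 * hidden_size for the two directions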


I registered forward hooks to catch NaNs or Infs in the module outputs, and I also check the parameters after every training iteration. The hook looks like this:

    def nan_hook(self, inp, output):
        # Collect the output tensors to check, unwrapping PackedSequence outputs.
        if isinstance(output, torch.nn.utils.rnn.PackedSequence):
            outputs = [output.data]
        elif not isinstance(output, tuple):
            if isinstance(output[0], torch.nn.utils.rnn.PackedSequence):
                outputs = [output[0].data]
            else:
                outputs = [output]
        else:
            if isinstance(output[0], torch.nn.utils.rnn.PackedSequence):
                outputs = [output[0].data]
            else:
                outputs = output

        for i, out in enumerate(outputs):
            nan_mask = torch.isnan(out)
            if nan_mask.any():
                print("In", self.__class__.__name__)
                raise RuntimeError(f"Found NaN in output {i} at indices:", nan_mask.nonzero(),
                                   "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])
            inf_mask = torch.isinf(out)
            if inf_mask.any():
                print("In", self.__class__.__name__)
                raise RuntimeError(f"Found Inf in output {i} at indices:", inf_mask.nonzero(),
                                   "where:", out[inf_mask.nonzero()[:, 0].unique(sorted=True)])

    for submodule in model.modules():
        submodule.register_forward_hook(nan_hook)
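Each register_forward_hook call returns a handle, so the hooks can be removed again with handle.remove() once the debugging is done.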

And this is how I check the parameters after every training iteration:

    for model in self.models_dict["model"]:
        for param in model.parameters():
            # flag any parameter tensor that contains a NaN or Inf
            if not torch.isfinite(param.data).all():
                print("param data:", torch.isfinite(param.data).all())

None of these checks fire, which indicates that there are no NaNs or Infs in the module outputs or in any of the parameters.

My machine configuration is:
NVIDIA-SMI: 460.27.04
Driver: 460.27.04
CUDA: 11.2
GPU: V100, 16 GB

I am using Python 3.8.5, PyTorch 1.6.0, and cudatoolkit 10.2.89.

This problem started after IT upgraded the machine; I am trying to find out from them exactly what was upgraded.

I tried upgrading to PyTorch 1.7 and cudatoolkit 11.0, but that combination seems to be unstable and has its own share of problems.

I have searched the forums and tried the usual techniques (reducing the learning rate, adding L2 regularization, etc.), but those just push the problem to a later iteration. Gradient clipping is also turned on. The problem never really goes away.
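For reference, the clipping is applied between backward() and the optimizer step, roughly like this (the optimizer choice, model, and all hyperparameter values here are illustrative stand-ins, not my exact settings):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 1)  # stand-in for the real LSTM/GNN model
    # L2 regularization via weight_decay (illustrative value)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    loss_fn = nn.MSELoss()

    x, target = torch.randn(4, 8), torch.randn(4, 1)

    optimizer.zero_grad()
    loss = loss_fn(model(x), target)
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm before clipping, which is
    # handy for spotting the norm blowing up just before the loss turns NaN
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print("grad norm before clipping:", float(total_norm))
    optimizer.step()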

Any suggestions on how to debug this would be very welcome.

Thanks,
Apurva