My model intermittently produces NaNs during training. From debugging, I found that on every occasion a Dropout layer was the first whose output contained NaN. Why would dropout be outputting NaNs?
The model is trained in mixed precision.
I added this hook to every module:
import torch

def nan_hook(self, inp, output):
    if not isinstance(output, tuple):
        outputs = [output]
    else:
        outputs = output
    for i, out in enumerate(outputs):
        nan_mask = torch.isnan(out)
        if nan_mask.any():
            print("Hook: Nan occured In", self.__class__.__name__)

for submodule in self.model.modules():
    submodule.register_forward_hook(nan_hook)
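One way to narrow this down (a sketch I have been using for illustration, not part of the original hook): forward hooks also receive the module's input, so the hook can report whether a NaN was created inside the module or was already present in what it received. Since dropout only zeroes elements and rescales the survivors, this usually shows the NaN arriving from upstream. The `nan_hook` variant below is hypothetical:

```python
import torch
import torch.nn as nn

def nan_hook(module, inp, output):
    # Check both sides of the module so a NaN created here can be told
    # apart from one propagated from an earlier layer.
    inputs = inp if isinstance(inp, tuple) else (inp,)
    outputs = output if isinstance(output, tuple) else (output,)
    nan_in = any(torch.isnan(t).any() for t in inputs if torch.is_tensor(t))
    nan_out = any(torch.isnan(t).any() for t in outputs if torch.is_tensor(t))
    if nan_out:
        origin = "propagated from input" if nan_in else "created here"
        print(f"Hook: NaN in {module.__class__.__name__} ({origin})")

# Example: a NaN fed into Dropout is reported as propagated, not created.
drop = nn.Dropout(p=0.0)  # p=0 keeps every element, so the NaN survives
drop.register_forward_hook(nan_hook)
out = drop(torch.tensor([float("nan"), 1.0]))
```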
Output (from 70_driver_log_9.txt):
SystemLog: 2020-02-22 17:13:45,253:DEBUG : transformers_pretraining.trainer.apexDDP : 159 : ***** Training step 172 *****
SystemLog: 2020-02-22 17:13:45,253:DEBUG : transformers_pretraining.utils : 47 : Inside <function Singleprocess._forward at 0x7f773402fb70>
Hook: Nan occured In Dropout
Hook: Nan occured In BertSelfAttention
Hook: Nan occured In Linear
Hook: Nan occured In Dropout
Hook: Nan occured In LayerNorm
Hook: Nan occured In BertSelfOutput
Hook: Nan occured In BertAttention
Model architecture:
BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=1024, out_features=1024, bias=True)
      (key): Linear(in_features=1024, out_features=1024, bias=True)
      (value): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=1024, out_features=4096, bias=True)
  )
  (output): BertOutput(
    (dense): Linear(in_features=4096, out_features=1024, bias=True)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
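Since the model runs in mixed precision, one mechanism worth checking (a hedged aside, not a confirmed diagnosis): fp16's largest finite value is 65504, and dropout rescales kept activations by 1/(1-p) at train time, so a near-max fp16 activation can overflow to inf inside dropout; and if the mask is applied by multiplication, a dropped inf element becomes inf * 0 = nan:

```python
import torch

# fp16 overflows to inf just above its max finite value, 65504.
x = torch.full((8,), 60000.0, dtype=torch.float16)
scale = 1.0 / (1.0 - 0.1)   # dropout's train-time rescaling for p=0.1
y = x * scale               # 60000 * 1.111... ~ 66667 > 65504 -> inf

# Multiplying an inf activation by a 0 mask entry yields NaN.
z = torch.tensor([float("inf")], dtype=torch.float16) * 0.0
```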