Weights become NaN after the first iteration

After the first Trainer iteration, my model weights become NaN, and I can't figure out why…

Here is my encoder model:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()

        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=kernel_size//2
        )

        self.norm = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout()

    def forward(self, x):
        return self.drop(
            self.relu(self.norm(self.conv(x))))

and my decoder model:

class ConvDecoder(nn.Module):
    def __init__(self, in_channels, vocab_size):
        super().__init__()
        self.decoder = nn.Conv1d(
            in_channels=in_channels, out_channels=vocab_size, kernel_size=1)

    def forward(self, x):
        return self.decoder(x)

These are simplified versions of my model that still produce the same error.
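For context, this is roughly how the two blocks chain together (a toy sketch; the batch size, time length, and vocab size here are made-up example values, not my real data):

import torch

x = torch.randn(10, 64, 1573)                 # (batch, mfcc features, time) -- made-up sizes
encoder = ConvBlock(in_channels=64, out_channels=128, kernel_size=33)
decoder = ConvDecoder(in_channels=128, vocab_size=28)
print(decoder(encoder(x)).shape)              # torch.Size([10, 28, 1573])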

and finally, my complete model:

class SpeechRecognitionModel(ModulePT):
    def __init__(self, vocab_size):
        super().__init__()
        self.acoustic_model = ConvBlock(64, 128, 33)
        self.decoder_model = ConvDecoder(
            in_channels=128, vocab_size=vocab_size)
        self.loss = nn.CTCLoss(blank=0)

    def forward(self, x):
        x = self.acoustic_model(x)
        return self.decoder_model(x)

    def configure_optimizers(self):
        return optim.Adamax(self.parameters(), 5e-4)

    def training_step(self, train_batch, *args):
        inputs, input_lengths, outputs, output_lengths = train_batch
        logits = self.forward(inputs)
        logits = logits.permute((-1, 0, 1))  # (N, C, T) -> (T, N, C), as expected by CTC
        probs = nn.functional.log_softmax(logits, dim=-1)
        loss = self.loss(probs, outputs, input_lengths, output_lengths)
        self.log("loss", loss)
        return loss
        
asr_model = SpeechRecognitionModel(vocab_size=tokenizer.vocab_size)

where ModulePT inherits from pl.LightningModule, and the model is trained with pl.Trainer on the CPU (pl refers to pytorch_lightning).
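The training itself is launched roughly like this (a sketch; max_epochs and the dataloader variable are placeholders for my actual setup):

trainer = pl.Trainer(max_epochs=1)   # runs on CPU since no GPU is requested
trainer.fit(asr_model, speech_recognition_dataloader)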

  • After the first training iteration, my loss function returned NaN, so I checked the model parameters. This is the result:
for param in asr_model.acoustic_model.parameters():
    print(param)

output:

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],
         ...

I double-checked my dataloader and dataset (a fine-tuned version of LibriSpeech) and made sure none of the input samples contain any NaN values; I ran torch.isnan(inputs).any() on every batch.
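Concretely, the check I ran looked like this (a minimal sketch over the same dataloader):

for inputs, input_lengths, outputs, output_lengths in speech_recognition_dataloader:
    # every batch passed this assertion, so the inputs themselves are clean
    assert not torch.isnan(inputs).any()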

What might be causing this problem?

  • Update:
    I replaced pl.Trainer with a raw training loop; here is my trainer code:
model = SpeechRecognitionModel(vocab_size=tokenizer.vocab_size)
optimizer = optim.AdamW(model.parameters(), 5e-4)
NAN = lambda x: torch.isnan(x).any()
loss = nn.CTCLoss(zero_infinity=True)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-4, 
                                            steps_per_epoch=int(len(speech_recognition_dataloader)),
                                            epochs=1,
                                            anneal_strategy='linear')
    

for batch_idx, batch in enumerate(speech_recognition_dataloader):
    mfccs, mfccs_lengths, labels, label_lengths = batch
    
    optimizer.zero_grad()

    output = model(mfccs)
    output = output.permute((-1, 0, 1))
    probs = nn.functional.log_softmax(output, dim=-1)
    _loss = loss(probs, labels, mfccs_lengths, label_lengths)
    _loss.backward()
    optimizer.step()
    scheduler.step()

    print(f"iter: {batch_idx}, input_data is nan: {NAN(mfccs)}, loss: {_loss}, output is nan: {NAN(output)}")

Here is what I get:

iter: 0, input_data is nan: False, loss: 19.476131439208984, output is nan: False
iter: 1, input_data is nan: False, loss: nan, output is nan: True
iter: 2, input_data is nan: False, loss: nan, output is nan: True
iter: 3, input_data is nan: False, loss: nan, output is nan: True
iter: 4, input_data is nan: False, loss: nan, output is nan: True

So now it is clear that the loss is computed on the first iteration, but the weights become NaN before the next one. Is this an optimizer problem?
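One extra check I can add to the loop above to narrow this down (a sketch; it only inspects the gradients between backward() and optimizer.step()):

    _loss.backward()
    # if any gradient is already NaN/inf here, the problem is in the backward pass,
    # not in the optimizer update itself
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print(f"iter {batch_idx}: non-finite grad in {name}")
    optimizer.step()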

Thanks

I followed the suggestions from this Topic.

My trainer:

model = torch.nn.Conv1d(64, tokenizer.vocab_size, kernel_size=33, padding=33//2)
optimizer = optim.AdamW(model.parameters(), 5e-4)
loss = nn.CTCLoss(zero_infinity=True)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-4, 
                                            steps_per_epoch=int(len(speech_recognition_dataloader)),
                                            epochs=1,
                                            anneal_strategy='linear')
    
torch.autograd.set_detect_anomaly(True)
for batch_idx, batch in enumerate(speech_recognition_dataloader):
    mfccs, mfccs_lengths, labels, label_lengths = batch
    
    optimizer.zero_grad()

    output = model(mfccs)
    output = output.permute((-1, 0, 1))
    print(output.size())
    probs = nn.functional.log_softmax(output, dim=-1)
    _loss = loss(probs, labels, mfccs_lengths, label_lengths)
    _loss.backward()
    optimizer.step()
    scheduler.step()

    print(f"iter: {batch_idx}, input_data is nan: {NAN(mfccs)}, loss: {_loss}, output is nan: {NAN(output)}")

The output that comes out of the nn.Conv1d model (after the permute) has shape (sequence, batch, features=28), where 28 is my vocab size.
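For reference, this is the shape contract I am trying to satisfy, shown on a tiny standalone example (made-up sizes, not my real data):

import torch
import torch.nn as nn

T, N, C = 50, 4, 28                                   # time steps, batch size, vocab size
ctc = nn.CTCLoss(blank=0, zero_infinity=True)
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)  # CTC expects (T, N, C) log-probabilities
targets = torch.randint(1, C, (N, 10))                # label indices, excluding blank=0
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)
print(ctc(log_probs, targets, input_lengths, target_lengths))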

I used torch.autograd.set_detect_anomaly(True), and here is the output:

torch.Size([1573, 10, 28])
/home/arsham/.local/lib/python3.8/site-packages/torch/autograd/__init__.py:147: UserWarning: Error detected in LogSoftmaxBackward. Traceback of forward call that caused the error:
  File "/home/arsham/.vscode/extensions/ms-toolsai.jupyter-2021.6.999662501/pythonFiles/vscode_datascience_helpers/kernel_prewarm_starter.py", line 31, in <module>
    runpy.run_module(module, run_name="__main__", alter_sys=False)
  File "/usr/lib/python3.8/runpy.py", line 210, in run_module
    return _run_code(code, {}, init_globals, run_name, mod_spec)
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/arsham/.local/lib/python3.8/site-packages/traitlets/config/application.py", line 845, in launch_instance
    app.start()
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 668, in start
    self.io_loop.start()
  File "/home/arsham/.local/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
    self._run_once()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
    handle._run()
  File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
    self._context.run(self._callback, *self._args)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 456, in dispatch_queue
    await self.process_one()
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 445, in process_one
    await dispatch(*args)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 352, in dispatch_shell
    await result
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 647, in execute_request
    reply_content = await reply_content
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 335, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 532, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2898, in run_cell
    result = self._run_cell(
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2944, in _run_cell
    return runner(coro)
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3169, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3361, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_13510/3759225913.py", line 19, in <module>
    probs = nn.functional.log_softmax(output, dim=0)
  File "/home/arsham/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1768, in log_softmax
    ret = input.log_softmax(dim)
 (Triggered internally at  /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_13510/3759225913.py in <module>
     19     probs = nn.functional.log_softmax(output, dim=0)
     20     _loss = loss(probs, labels, mfccs_lengths, label_lengths)
---> 21     _loss.backward()
     22     optimizer.step()
     23     scheduler.step()

~/.local/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253                 create_graph=create_graph,
    254                 inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256 
    257     def register_hook(self, hook):

~/.local/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145         retain_graph = create_graph
    146 
--> 147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Function 'LogSoftmaxBackward' returned nan values in its 0th output.

This error comes from this line: probs = nn.functional.log_softmax(output, dim=-1), where output.size() == (1573, 10, 28).

Then I removed the log_softmax part and computed the loss directly on the model output, like so:

    output = model(mfccs)
    output = output.permute((-1, 0, 1))
    probs = output
    _loss = loss(probs, labels, mfccs_lengths, label_lengths)

and I get an anomaly from MkldnnConvolutionBackward:

/home/arsham/.local/lib/python3.8/site-packages/torch/autograd/__init__.py:147: UserWarning: Error detected in MkldnnConvolutionBackward. Traceback of forward call that caused the error:
  File "/home/arsham/.vscode/extensions/ms-toolsai.jupyter-2021.6.999662501/pythonFiles/vscode_datascience_helpers/kernel_prewarm_starter.py", line 31, in <module>
    runpy.run_module(module, run_name="__main__", alter_sys=False)
  File "/usr/lib/python3.8/runpy.py", line 210, in run_module
    return _run_code(code, {}, init_globals, run_name, mod_spec)
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/arsham/.local/lib/python3.8/site-packages/traitlets/config/application.py", line 845, in launch_instance
    app.start()
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 668, in start
    self.io_loop.start()
  File "/home/arsham/.local/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
    self._run_once()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
    handle._run()
  File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
    self._context.run(self._callback, *self._args)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 456, in dispatch_queue
    await self.process_one()
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 445, in process_one
    await dispatch(*args)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 352, in dispatch_shell
    await result
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 647, in execute_request
    reply_content = await reply_content
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 335, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/arsham/.local/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 532, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2898, in run_cell
    result = self._run_cell(
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2944, in _run_cell
    return runner(coro)
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3169, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3361, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/home/arsham/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_13510/3867036866.py", line 15, in <module>
    output = model(mfccs)
  File "/home/arsham/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/arsham/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 298, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/arsham/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 294, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
 (Triggered internally at  /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_13510/3867036866.py in <module>
     17     probs = output
     18     _loss = loss(probs, labels, mfccs_lengths, label_lengths)
---> 19     _loss.backward()
     20     optimizer.step()
     21     scheduler.step()

~/.local/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253                 create_graph=create_graph,
    254                 inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256 
    257     def register_hook(self, hook):

~/.local/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145         retain_graph = create_graph
    146 
--> 147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Function 'MkldnnConvolutionBackward' returned nan values in its 1th output.

@ptrblck I hope you can assist me :frowning_face: