Error detected in CudnnBatchNormBackward

I am training a model. It ran smoothly for 10 epochs, but during epoch 11, at iteration 1022, it raised an error. I think it's related to batchnorm, but I am not sure how to tackle it. The autograd anomaly detector gives the following output. I am training in half precision with amp_opt_level 1.

(BS 8) loss: 0.6229:  41%|█████▋        | 1022/2505 [1:01:46<1:29:42,  3.63s/it][W python_anomaly_mode.cpp:60] Warning: Error detected in CudnnBatchNormBackward. Traceback of forward call that caused the error:
  File "main.py", line 725, in <module>
    main()
  File "main.py", line 721, in main
    processor.start()
  File "main.py", line 657, in start
    self.train(epoch, save_model=save_model)
  File "main.py", line 502, in train
    output = self.model(batch_data)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/linux/phd_codes/models_pristine/MS-G3D_part_GCN_with_GRU_AE/model/msg3d.py", line 189, in forward
    x = F.relu(self.sgcn1(x) + self.gcn3d1(x), inplace=True)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/linux/phd_codes/models_pristine/MS-G3D_part_GCN_with_GRU_AE/model/ms_tcn.py", line 97, in forward
    out = tempconv(x)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 131, in forward
    return F.batch_norm(
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2014, in batch_norm
    return torch.batch_norm(
 (function print_stack)
(BS 8) loss: 0.6229:  41%|█████▋        | 1022/2505 [1:01:47<1:29:39,  3.63s/it]
Traceback (most recent call last):
  File "main.py", line 725, in <module>
    main()
  File "main.py", line 721, in main
    processor.start()
  File "main.py", line 657, in start
    self.train(epoch, save_model=save_model)
  File "main.py", line 513, in train
    scaled_loss.backward()
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/linux/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 125, in backward
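Because the run uses half precision, one thing worth ruling out is an fp16 overflow somewhere in the forward or backward pass. The plain-Python sketch below is only an illustration, not what cuDNN does internally. It rounds values to IEEE 754 half precision with the standard-library struct "e" format. Any value past the fp16 maximum of 65504 becomes inf, and arithmetic such as inf - inf then yields NaN. Per the PyTorch docs, anomaly mode raises an error when a backward computation produces NaN. The helper names to_fp16 and first_non_finite, and the sample activations, are hypothetical and only show the idea of locating the first non-finite intermediate.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision and back to a Python float.

    Magnitudes beyond the fp16 range become +/-inf, which is what a
    float16 tensor would hold after overflow.
    """
    try:
        # The struct "e" format packs a float as IEEE 754 binary16.
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses out-of-range values; emulate fp16 overflow to inf.
        return math.copysign(math.inf, x)


def first_non_finite(named_values):
    """Hypothetical helper: return the name of the first (name, values)
    pair whose values contain NaN or inf, or None if all are finite."""
    for name, values in named_values:
        if any(not math.isfinite(v) for v in values):
            return name
    return None


if __name__ == "__main__":
    # Illustrative activations; squaring them (as a variance computation
    # would) pushes some of them past FP16_MAX.
    activations = [300.0, -250.0, 280.0]
    squares = [to_fp16(a * a) for a in activations]
    print(squares)  # [inf, 62496.0, inf]
    # Once an inf appears, later arithmetic can turn it into NaN.
    print(squares[0] - squares[2])  # nan
    print(first_non_finite([("activations", activations), ("squares", squares)]))  # squares
```

In a real model the same check would be applied to tensors, for example with torch.isfinite on intermediate outputs. The point here is only that a single out-of-range intermediate is enough to poison everything computed from it.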

Is this the full stacktrace? It looks like the output is truncated.

This is just the part of the training output where the error occurred; training stopped because I had set autograd_anomaly = True. Until epoch 11, training was smooth.
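The epoch-10 log below contains two "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0" messages. These come from dynamic loss scaling. When the scaled gradients contain inf or NaN, the optimizer step is skipped and the loss scale is lowered. The toy scaler below sketches that policy in plain Python. The halve-on-overflow and double-after-N-clean-steps rule, the class name DynamicLossScaler, and the parameter names and defaults (init_scale, backoff, growth, growth_interval=2000) are assumptions for illustration, not apex's actual implementation.

```python
import math


class DynamicLossScaler:
    """Toy dynamic loss scaler (illustrative policy, not apex's code).

    Assumed policy: on a non-finite gradient, skip the step and multiply the
    scale by `backoff`; after `growth_interval` consecutive clean steps,
    multiply the scale by `growth`.
    """

    def __init__(self, init_scale=2.0 ** 16, backoff=0.5, growth=2.0,
                 growth_interval=2000):
        self.scale = init_scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, grads):
        """Inspect the (scaled) gradients of one iteration.

        Returns True if the optimizer step should run, False if it is skipped.
        """
        if any(not math.isfinite(g) for g in grads):
            # Overflow: lower the scale and restart the clean-step count.
            self.scale *= self.backoff
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps % self.growth_interval == 0:
            # A long run without overflow: try a larger scale again.
            self.scale *= self.growth
        return True


if __name__ == "__main__":
    scaler = DynamicLossScaler(init_scale=262144.0)
    print(scaler.update([0.01, math.inf]), scaler.scale)  # False 131072.0
    print(scaler.update([0.01, 0.02]), scaler.scale)      # True 131072.0
```

Starting from 262144.0, one overflowing iteration skips the step and drops the scale to 131072.0, the value in the log. Occasional skips like this are expected behaviour. Note, however, that the skip decision is made only after backward finishes, whereas anomaly mode raises as soon as a backward function produces NaN. With anomaly detection on, an overflow that AMP would normally absorb by skipping the step may therefore stop training instead.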

[ Tue Jun  1 07:36:56 2021 ] Training epoch: 10, LR: 0.0500
(BS 8) loss: 0.4095:  45%|██████▎       | 1127/2505 [1:07:54<1:23:12,  3.62s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0
(BS 8) loss: 0.5079:  90%|██████████████▎ | 2243/2505 [2:15:27<15:52,  3.63s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0
(BS 8) loss: 0.0766: 100%|████████████████| 2505/2505 [2:31:18<00:00,  3.62s/it]
[ Tue Jun  1 10:08:15 2021 ] 	Mean training loss: 0.4249 (BS 16: 0.8497).
[ Tue Jun  1 10:08:15 2021 ] 	Time consumption: [Data]00%, [Network]99%
[ Tue Jun  1 10:08:15 2021 ] Eval epoch: 10
100%|███████████████████████████████████████| 1031/1031 [03:59<00:00,  4.31it/s]
Accuracy:  0.7487717595681446  model:  msg3d_with_part
[ Tue Jun  1 10:12:14 2021 ] 	Mean test loss of 1031 batches: 0.8397054561819415.
[ Tue Jun  1 10:12:15 2021 ] 	Top 1: 74.88%
[ Tue Jun  1 10:12:15 2021 ] 	Top 5: 94.84%
[ Tue Jun  1 10:12:15 2021 ] Training epoch: 11, LR: 0.0500
(BS 8) loss: 0.6229:  41%|█████▋        | 1022/2505 [1:01:46<1:29:42,  3.63s/it][W python_anomaly_mode.cpp:60] Warning: Error detected in CudnnBatchNormBackward. Traceback of forward call that caused the error:
  File "main.py", line 725, in <module>
    main()