Another NaN problem during training

I am training a modified DCCRN model for denoising. Training suddenly produced NaN, so I added logging to the training loop that checks whether the parameters/gradients contain inf/NaN values and prints the max/min value of each layer, like this:

    if is_training:
        self._backward_batch(loss)
        # The learning rate scheduler.step MUST be called before
        # optimizer step for cold start
        self.lr_scheduler.step()
        with self.timer['merge']:
            try:
                # Gradient statistics for every layer, before the optimizer step
                for name, param in self.model.named_parameters():
                    if param.grad is None:  # e.g. frozen or unused parameters
                        continue
                    logger.info("{} is the maximum gradient on {} layer".format(torch.max(param.grad), name))
                    logger.info("{} is the minimum gradient on {} layer".format(torch.min(param.grad), name))
                    logger.info("{} layer has nan gradient ??|{}".format(name, torch.isnan(param.grad).any()))
                    # torch.isinf, not torch.isfinite: isfinite(...).any() is True as soon as
                    # ONE entry is finite, so it can never flag an inf gradient
                    logger.info("{} layer has infinite gradient ??|{}".format(name, torch.isinf(param.grad).any()))

                self.dist_optim.step(frames)
            except EarlyStop:
                data_queue.clear()
                break  # leaves the enclosing batch loop

            # Parameter statistics for every layer, after the optimizer step
            sd = self.model.state_dict()
            for param_tensor in sd:
                logger.info("{} layer has nan parameter??|{}".format(param_tensor, torch.isnan(sd[param_tensor]).any()))
                logger.info("{} is maximum parameter in {}".format(torch.max(sd[param_tensor]), param_tensor))
                logger.info("{} is min parameter in {}".format(torch.min(sd[param_tensor]), param_tensor))

            # the last interval will be retained
            self.timer['wall'].checkpoint()
            self.timer['wall'].start()

            epoch_metric.accumulate(loss_statistics)
            local_metric.accumulate(loss_statistics)
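
The per-layer checks above can be condensed into one helper that only reports layers with a problem, which keeps the log short on 16 machines. This is a minimal sketch assuming a plain `torch.nn.Module`; `find_nonfinite` is a hypothetical name, not a PyTorch API. It uses `torch.isfinite(t).all()`, which is False when any entry is NaN, +inf or -inf, so NaN and inf are caught by a single test.

```python
import torch
import torch.nn as nn


def find_nonfinite(model: nn.Module) -> dict:
    """Return {parameter name: ["value"] and/or ["grad"]} for tensors with NaN/inf.

    Hypothetical helper, not part of PyTorch.
    """
    problems = {}
    for name, param in model.named_parameters():
        issues = []
        # .all() is False if ANY entry is NaN, +inf or -inf.
        # (isfinite(...).any() would only be False if EVERY entry were non-finite.)
        if not torch.isfinite(param.detach()).all():
            issues.append("value")
        if param.grad is not None and not torch.isfinite(param.grad).all():
            issues.append("grad")
        if issues:
            problems[name] = issues
    return problems


if __name__ == "__main__":
    model = nn.Linear(3, 2)
    model(torch.randn(4, 3)).sum().backward()
    print(find_nonfinite(model))  # {} : all values and gradients are finite

    with torch.no_grad():
        model.weight[0, 0] = float("nan")  # inject a NaN to show detection
    print(find_nonfinite(model))  # {'weight': ['value']}
```

Calling it once before and once after the optimizer step tells you whether a NaN came in through the gradients or was created by the step itself.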

I have also added logging on the inputs to make sure they contain no abnormal values (and they indeed do not).
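
A variant of the input check that fails fast instead of only logging, so training stops at the first bad batch and the offending sample is easy to locate. This is a sketch assuming each batch is a dict of tensors (the post does not show the batch structure); `assert_finite_batch` is a hypothetical name.

```python
import torch


def assert_finite_batch(batch: dict) -> None:
    """Raise ValueError if any tensor in a dict-shaped batch holds NaN/inf.

    Hypothetical helper; non-tensor entries (ids, lengths as ints, ...) are skipped.
    """
    for key, tensor in batch.items():
        if not torch.is_tensor(tensor):
            continue
        finite = torch.isfinite(tensor)
        if not finite.all():
            bad = int((~finite).sum().item())  # number of NaN/inf entries
            raise ValueError(f"input '{key}' has {bad} non-finite values")


if __name__ == "__main__":
    batch = {"noisy": torch.randn(4, 16000), "clean": torch.randn(4, 16000)}
    assert_finite_batch(batch)  # passes silently
```
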
I use a batch_size of 4. Each machine (16 in total) has about 9000 training samples, and the NaN occurs after about 47000 mini-batch iterations (around 16 epochs). I ran the experiment twice and found that the NaN always appears first in the parameters (right after self.dist_optim.step(frames)), always in one specific layer, and all parameters turn into NaN at the same time. The gradients just before the parameters go to NaN are never abnormal.
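
Since the parameters become NaN right after the optimizer step while the logged gradients look normal, one way to narrow this down is to check the parameters immediately after the step and, for each one that went non-finite, record its pre-step magnitude, the gradient magnitude and whether the optimizer's own state buffers are non-finite. This is a minimal sketch assuming a standard `torch.optim` optimizer; `dist_optim` in the post is a custom wrapper whose internals are not shown, and `step_and_diagnose` is a hypothetical name.

```python
import torch
import torch.nn as nn


def step_and_diagnose(model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    """Call optimizer.step() and describe every parameter that became NaN/inf.

    Hypothetical helper; assumes gradients are already populated. Cloning every
    parameter doubles parameter memory, so enable it only while hunting the bug.
    """
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    optimizer.step()

    report = {}
    for name, p in model.named_parameters():
        if torch.isfinite(p).all():
            continue
        grad = p.grad
        # optimizer.state maps each parameter to its buffers (e.g. exp_avg and
        # exp_avg_sq for Adam); the keys depend on the optimizer.
        state = optimizer.state.get(p, {})
        report[name] = {
            "param_absmax_before": before[name].abs().max().item(),
            "grad_absmax": None if grad is None else grad.abs().max().item(),
            "grad_is_finite": grad is None or bool(torch.isfinite(grad).all()),
            "nonfinite_state": [
                key for key, value in state.items()
                if torch.is_tensor(value) and not torch.isfinite(value).all()
            ],
        }
    return report


if __name__ == "__main__":
    model = nn.Linear(2, 1)
    opt = torch.optim.SGD(model.parameters(), lr=10.0)
    # Huge but finite gradient: lr * grad overflows float32, so the weight becomes -inf
    model.weight.grad = torch.full_like(model.weight, 3e38)
    print(step_and_diagnose(model, opt))
```

The demo shows that finite gradients alone do not guarantee finite parameters: an update that overflows, or optimizer state that has itself become NaN/inf, can corrupt a layer in a single step.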

I would like to understand why the NaN occurs in my case. I can show some interesting facts about this; here rank0, rank1, rank2 simply mean machine 1, machine 2, machine 3 …