I'm using a modified DCCRN model for denoising. During training I suddenly got NaN, so I added logging to the training loop to check whether any parameter or gradient contains inf/NaN values, and to output the max/min values of each layer:
```python
if is_training:
    self._backward_batch(loss)
    # The learning rate scheduler.step MUST be called before
    # optimizer step for cold start
    self.lr_scheduler.step()
    with self.timer['merge']:
        try:
            # Log gradient statistics of every layer before the update
            for name, param in self.model.named_parameters():
                if param.grad is None:  # parameter unused in this backward pass
                    continue
                logger.info("{} is the maximum gradient on {} layer"
                            .format(torch.max(param.grad), name))
                logger.info("{} is the minimum gradient on {} layer"
                            .format(torch.min(param.grad), name))
                logger.info("{} layer has nan gradient ??|{}"
                            .format(name, torch.isnan(param.grad).any()))
                # torch.isinf flags infinite entries; torch.isfinite(...).any()
                # would only report whether at least one entry is finite
                logger.info("{} layer has infinite gradient ??|{}"
                            .format(name, torch.isinf(param.grad).any()))
            self.dist_optim.step(frames)
        except EarlyStop:
            data_queue.clear()
            break
    # Log parameter statistics after the optimizer step
    sd = self.model.state_dict()
    for param_tensor in sd:
        logger.info("{} layer has nan parameter ??|{}"
                    .format(param_tensor, torch.isnan(sd[param_tensor]).any()))
        logger.info("{} is maximum parameter in {}"
                    .format(torch.max(sd[param_tensor]), param_tensor))
        logger.info("{} is min parameter in {}"
                    .format(torch.min(sd[param_tensor]), param_tensor))
# the last interval will be retained
self.timer['wall'].checkpoint()
self.timer['wall'].start()
epoch_metric.accumulate(loss_statistics)
local_metric.accumulate(loss_statistics)
```
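Logging the max/min of every layer at every step produces a lot of output. A more compact check might look like the sketch below, which uses `torch.isfinite` (it catches both NaN and ±Inf in one call) and returns only the names of offending parameters and gradients. The helper name `report_nonfinite` is hypothetical and not part of my training code.

```python
import torch
import torch.nn as nn


def report_nonfinite(model: nn.Module) -> dict:
    """Return names of parameters whose values or gradients contain NaN/Inf.

    torch.isfinite(...).all() is False if any entry is NaN, +Inf or -Inf,
    so a single check per tensor covers both cases.
    """
    bad = {"param": [], "grad": []}
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            bad["param"].append(name)
        # param.grad is None for parameters that did not take part in backward
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad["grad"].append(name)
    return bad


if __name__ == "__main__":
    model = nn.Linear(3, 2)
    model(torch.randn(4, 3)).sum().backward()
    print(report_nonfinite(model))  # {'param': [], 'grad': []}
    with torch.no_grad():
        model.weight[0, 0] = float("nan")  # simulate a corrupted parameter
    print(report_nonfinite(model))  # {'param': ['weight'], 'grad': []}
```

Only the (usually empty) result would need to be logged each step, e.g. `if any(bad.values()): logger.info(bad)`.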
I also log the input to make sure it doesn't contain abnormal values (and it indeed doesn't).
I use a batch size of 4. Each machine (16 in total) has about 9,000 training samples, and the NaN appears after about 47,000 mini-batches (around 16 epochs). I ran the experiment twice and saw the same pattern both times: the NaN always first appears in the parameters (right after `self.dist_optim.step(frames)`), always in one specific layer, and all parameters turn into NaN at the same time. The gradients from the step before the parameters become NaN are never abnormal.
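Since the gradients look finite but the parameters become NaN right after the step, it may help to check the update itself. Optimizers such as Adam keep per-parameter running moment estimates, so the update can go non-finite even when the current gradient is fine. The sketch below is a hypothetical helper (`checked_step`); it assumes a standard `torch.optim` optimizer exposing `.state`, which my `dist_optim` wrapper may not. It checks gradients before the step, then parameters and optimizer state after it, and returns the largest absolute update per layer so a growing update in one layer could be spotted before the NaN.

```python
import torch
import torch.nn as nn


def checked_step(model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    """Run optimizer.step() and report non-finite values it produces.

    Assumes a standard torch.optim optimizer. Returns {layer: max |update|}.
    """
    params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    # 1) Gradients must be finite before the update
    for n, p in params:
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise FloatingPointError(f"non-finite gradient in {n} before step")
    before = [p.detach().clone() for _, p in params]
    optimizer.step()
    # 2) Parameters and per-parameter optimizer state after the update
    bad, max_update = [], {}
    for (n, p), old in zip(params, before):
        if not torch.isfinite(p).all():
            bad.append(n)
        else:
            max_update[n] = (p.detach() - old).abs().max().item()
        for key, value in optimizer.state.get(p, {}).items():
            if (torch.is_tensor(value) and value.is_floating_point()
                    and not torch.isfinite(value).all()):
                bad.append(f"{n} (optimizer state {key!r})")
    if bad:
        raise FloatingPointError(f"non-finite values after step: {bad}")
    return max_update


if __name__ == "__main__":
    model = nn.Linear(3, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(4, 3)).sum().backward()
    print(checked_step(model, opt))  # largest update per layer, about lr
```

Raising an exception keeps the first offending step visible instead of letting NaN propagate for many more iterations.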
I'd like to understand why the NaN occurs in my case. I can share some interesting observations on the matter; here rank0, rank1, rank2 simply mean machine 1, machine 2, machine 3 …