PyTorch DDP - optimizer.step() randomly takes a huge amount of time on some batches

I’m using PyTorch DDP to train a ResNet101 across 4 p3.16xlarge EC2 instances, each with 8 GPUs. Training proceeds normally for some epochs and then suddenly slows down because some batches take a huge amount of time (~3000 seconds). I profiled the run and found that the slowdown always shows up at the optimizer.step() call. The 4 instances are in the same availability zone, the same subnet, and the same cluster placement group. Sometimes the training is slow (spikes on some batches) right from the start. It does not look like a data issue: the extra time is not spent in batch loading, and the slowdown does not occur in every epoch. Any leads on resolving this would be greatly appreciated.
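
For reference, the setup follows the standard DDP recipe, roughly along these lines (model and hyperparameters simplified, data loading omitted):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet101

# One process per GPU, launched with torchrun on each of the 4 nodes (32 ranks total).
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(resnet101().cuda(local_rank), device_ids=[local_rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

train_loader = ...  # DataLoader over the dataset with a DistributedSampler (omitted here)

for images, targets in train_loader:
    images = images.cuda(local_rank, non_blocking=True)
    targets = targets.cuda(local_rank, non_blocking=True)
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    loss.backward()   # DDP overlaps the NCCL gradient all-reduce with backward
    optimizer.step()  # this is the call that occasionally spikes
```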

How did you profile the code and narrow it down to the optimizer.step() method?
Are you sure the actual step() call is causing the slowdown, or could step() instead be waiting for the other nodes to finish their backward pass (and the accompanying gradient all-reduce) before it can start updating the parameters?
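
Since CUDA kernels and the NCCL all-reduce run asynchronously, a plain time.time() around step() will also absorb whatever GPU/communication work is still in flight. One way to check is to bracket the call with explicit synchronization inside your existing training loop, roughly like this (using your loop's optimizer):

```python
import time
import torch

torch.cuda.synchronize()  # wait for backward and the gradient all-reduce to finish
t0 = time.time()
optimizer.step()
torch.cuda.synchronize()  # make sure the parameter update itself has completed
print(f"step() alone took {time.time() - t0:.3f}s")
```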

Thanks for your response! I just used time.time() before and after the statement and logged the result to TensorBoard, so the measurement could indeed include time during which step() was blocked. Also, if I use AdamW instead of SGD, the spike shows up in backward() instead of step().
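
For a cleaner breakdown I'll re-measure with torch.profiler, which attributes CPU, CUDA, and communication time separately and can write traces for TensorBoard; something along these lines wrapped around the existing loop (log directory name is arbitrary):

```python
from torch.profiler import (ProfilerActivity, profile, schedule,
                            tensorboard_trace_handler)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=5, warmup=2, active=6, repeat=1),  # profile a short window
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),
    record_shapes=True,
) as prof:
    for images, targets in train_loader:  # same loop as before
        # ... forward / backward / optimizer.step() as usual ...
        prof.step()  # advance the profiler schedule each iteration
```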