Optimizer freezes for a while after every 11 iterations

Hi there, I’m profiling my training code and noticed that every 11 iterations (batches) the code freezes for a while when calling optimizer.step(), then resumes normally for the next 11 iterations. Does this have to do with gradient checking, or is it maybe a bug in my code?

The training loop is as follows:

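```python
# Excerpt from a training method (uses self.forward, self.criterion, self.optimizer).
import torch
from tqdm import tqdm

for i, (index, im, mask) in tqdm(enumerate(train_loader), total=n_iter):
    im = im.cuda()
    mask = mask.cuda()
    # Forward propagation
    logit = self.forward(im)
    loss = self.criterion(logit, mask)

    # Backward propagation (not shown in the original post, but required
    # before optimizer.step() has any gradients to apply)
    self.optimizer.zero_grad()
    loss.backward()

    # Compute metrics
    if i % metric_step == 0:
        pred = torch.sigmoid(logit)
        iou = eval.dice_accuracy(pred.detach().cpu().numpy(), ...)  # remaining arguments were cut off in the original post

    self.optimizer.step()  # <- freezes here on iterations that are multiples of 11
```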

Many thanks.

Kind regards,

Are you using this number (11) for any parameters, e.g. metric_step, batch_size, etc.?
Based on this code I can’t see obvious reasons for the freeze.

No, no parameter is equal to 11. The issue is not the metrics: I’ve observed that it freezes on self.optimizer.step(), and I’ve also tried skipping the metric computation completely.
As for batch_size, I’ve tried 32 and 50. With 32 the freeze seems a little shorter. I’m away from my computer now, but if I’m not mistaken I’m using PyTorch 0.4.x. Could it be a bug, and should I try updating PyTorch?


Could it be that your code freezes in each epoch?
Are you using multiple workers in your DataLoader? If so, are you preloading the data in the Dataset's __init__ or are you lazily loading the data?


It freezes every 11 batches, well before an epoch completes. I haven’t paid attention to what happens at the end of the epoch.

Yes, my DataLoader has multiple workers: 11, if I’m not mistaken (Ooooh! Eureka moment). It loads each image lazily when the batch is requested, because I have a lot of data. This must be the culprit. Is there a way to asynchronously prepare the next batch on the CPU while the GPU is training on the current one?

Kind regards,

The DataLoader uses multiprocessing to load the batches asynchronously while the training takes place.
However, if you have some heavy preprocessing in your Dataset or the data loading is IO bound, you might notice small freezes as the workers can’t keep up processing the data fast enough.
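To see why the stall period can match the number of workers, here is a toy timing simulation in plain Python (not PyTorch code). It assumes each worker needs `load_time` seconds to build a whole batch, that batches are handed to the workers round-robin and returned in order (as the DataLoader does by default), and that one training step takes `step_time` seconds, much less than `load_time`. The function name and the numbers are hypothetical.

```python
def simulate_waits(num_batches, num_workers, load_time, step_time):
    """Return how long the training loop waits for each batch (toy model)."""
    waits = []
    clock = 0.0  # time at which the training loop asks for the next batch
    for i in range(num_batches):
        # Batch i is built by worker i % num_workers as its (i // num_workers)-th job,
        # and all workers start at t = 0 and work sequentially.
        ready = (i // num_workers + 1) * load_time
        wait = max(0.0, ready - clock)
        waits.append(wait)
        clock += wait + step_time  # wait for the data, then run one training step
    return waits


waits = simulate_waits(num_batches=33, num_workers=11, load_time=5.0, step_time=0.1)
stalls = [i for i, w in enumerate(waits) if w > 0]
print(stalls)  # [0, 11, 22]
```

With 11 workers the loop only waits at batches 0, 11 and 22: all workers finish a batch at roughly the same time, the loop consumes those 11 batches quickly, and then it has to wait for the next round. Real timings will differ, but the periodic pattern is the same one described in the question.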

You could try to increase the number of workers, store your data on an SSD (if that’s not the case already), or try to speed up the preprocessing.
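Before changing anything, it can help to measure how much of each iteration is spent waiting for data versus running the training step. Below is a minimal, framework-agnostic sketch; `profile_loader`, `train_step` and the fake `slow_loader` are hypothetical names for illustration. With a real CUDA model, kernels run asynchronously, so `train_step` should call `torch.cuda.synchronize()` before returning; otherwise GPU time can show up in whichever later call happens to block.

```python
import time


def profile_loader(loader, train_step):
    """Return per-batch lists (data_times, step_times) in seconds."""
    data_times, step_times = [], []
    it = iter(loader)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent here = waiting for the loader
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)  # with CUDA, synchronize inside train_step for honest timings
        t2 = time.perf_counter()
        data_times.append(t1 - t0)
        step_times.append(t2 - t1)
    return data_times, step_times


def slow_loader():
    """Fake loader: every 3rd batch takes 50 ms to 'load'."""
    for i in range(6):
        if i % 3 == 0:
            time.sleep(0.05)
        yield i


data_times, step_times = profile_loader(slow_loader(), train_step=lambda batch: None)
for i, (d, s) in enumerate(zip(data_times, step_times)):
    print(f"batch {i}: data {d * 1000:.1f} ms, step {s * 1000:.1f} ms")
```

If the data times spike periodically while the step times stay flat, the workers are the bottleneck rather than the optimizer.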


I was having exactly the same problem: every couple of epochs, the progress would freeze for a few seconds.

After trying many things, I was able to fix this issue by wrapping my dataloader in a prefetcher and preprocessing the data (in my case, frames of videos) on the GPU. This fixed the problem and resulted in a 6x speedup!

Before, I was preprocessing my frames inside the __getitem__ function of the dataset, pretty much as shown in the PyTorch data loading tutorial. The only preprocessing I was doing was a conversion to float (my raw data is uint8) and a resize (shrinking).

So I changed this. Only the code that reads the data point from the hard drive remained inside __getitem__. Then I added the DataPrefetcher class, which is both a wrapper around and a drop-in replacement for the dataloader.
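A minimal sketch of such an I/O-only Dataset, assuming (hypothetically) that each video is stored as a uint8 NumPy `.npy` file with a (C, T, H, W) layout and that labels are integers. `RawVideoDataset` and the file layout are illustrations, not the poster’s code. Keeping the data as uint8 also means each batch moved to the GPU is four times smaller than the same data as float32.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class RawVideoDataset(Dataset):
    """Hypothetical dataset: one uint8 .npy file per video, integer labels."""

    def __init__(self, paths, labels):
        self.paths = list(paths)
        self.labels = list(labels)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Only read from disk here: no float conversion, no resizing.
        video = torch.from_numpy(np.load(self.paths[idx]))  # stays uint8
        label = torch.tensor(self.labels[idx])
        return video, label


# Tiny demo: two random (C, T, H, W) "videos" written to a temporary directory.
tmp_dir = tempfile.mkdtemp()
paths = []
for k in range(2):
    path = os.path.join(tmp_dir, f"video_{k}.npy")
    np.save(path, np.random.randint(0, 256, size=(3, 4, 16, 16), dtype=np.uint8))
    paths.append(path)

dataset = RawVideoDataset(paths, labels=[0, 1])
video, label = dataset[0]
print(video.dtype, tuple(video.shape))  # torch.uint8 (3, 4, 16, 16)
```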

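```python
import torch


class DataPrefetcher():
    """Wraps a DataLoader; copies and preprocesses the next batch on a side CUDA stream."""

    def __init__(self, dataloader, img_shape, device):
        self.dataloader = dataloader
        self._len = len(dataloader)
        self.device = device
        self.stream = torch.cuda.Stream()
        self.img_shape = img_shape

    def prefetch(self):
        try:
            self.next_video, self.next_label = next(self.dl_iter)
        except StopIteration:
            self.next_video = None
            self.next_label = None
            return  # nothing left to copy

        # Issue the host-to-device copy and the preprocessing on a side stream
        # so they can overlap with the training step on the default stream.
        with torch.cuda.stream(self.stream):
            self.next_label = self.next_label.to(self.device, non_blocking=True)
            self.next_video = self.next_video.to(self.device, non_blocking=True)

            # uint8 -> float and resize, now done on the GPU for the whole batch.
            self.next_video = self.next_video.float()
            # The remaining interpolate arguments were cut off in the original post;
            # size=self.img_shape and the default mode ("nearest") are assumptions.
            # img_shape must match the number of spatial dimensions of the batch.
            self.next_video = torch.nn.functional.interpolate(
                self.next_video, size=self.img_shape)

    def __iter__(self):
        self.dl_iter = iter(self.dataloader)
        self.prefetch()  # start copying the first batch
        return self

    def __len__(self):
        return self._len

    def __next__(self):
        # Make the default stream wait until the side-stream work has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        video = self.next_video
        label = self.next_label

        if video is None or label is None:
            raise StopIteration

        # As in the apex example: mark the tensors as used on the current stream
        # so the caching allocator does not reuse their memory too early.
        video.record_stream(torch.cuda.current_stream())
        label.record_stream(torch.cuda.current_stream())

        self.prefetch()  # start copying the next batch while this one is trained on
        return video, label
```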

(This code was adapted from apex/amp.)

Another advantage of this approach is that F.interpolate will resize all frames of the batch “at once”. Before, I had to loop over the frames of a video and resize them one by one (I was using scikit-image for that).
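One way to resize every frame of a batch with a single call is sketched below. It assumes a hypothetical (N, T, C, H, W) uint8 layout and folds the frame axis into the batch axis, so that 2D `mode="bilinear"` (which expects a 4D input) can be used and the number of frames stays fixed. If a 5D tensor is passed to `interpolate` directly, it is resized over three dimensions, including time. `resize_frames` is a hypothetical helper; on a GPU you would pass `device="cuda"`.

```python
import torch
import torch.nn.functional as F


def resize_frames(video_u8, size, device="cpu"):
    """(N, T, C, H, W) uint8 -> (N, T, C, size[0], size[1]) float32, one interpolate call."""
    n, t, c, h, w = video_u8.shape
    # Move the compact uint8 data first, then convert and resize on the target device.
    frames = video_u8.to(device).reshape(n * t, c, h, w).float()
    frames = F.interpolate(frames, size=size, mode="bilinear", align_corners=False)
    return frames.reshape(n, t, c, size[0], size[1])


batch = torch.randint(0, 256, (2, 8, 3, 64, 64), dtype=torch.uint8)
out = resize_frames(batch, (32, 32))
print(tuple(out.shape), out.dtype)  # (2, 8, 3, 32, 32) torch.float32
```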

Some extra remarks concerning the DataLoader. In my case, I got the best performance with:
- pin_memory = False
- num_workers = min(n_cores, batch_size)
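The two settings above can be written as DataLoader keyword arguments; the sketch computes `num_workers` from the available cores. `loader_kwargs` is a hypothetical helper. These values were the best in the poster’s setup, so treat them as a starting point to benchmark rather than a rule. In particular, `non_blocking=True` host-to-device copies can only overlap with GPU work when the source tensors are in pinned memory, so it may be worth timing `pin_memory=True` as well.

```python
import os


def loader_kwargs(batch_size):
    """DataLoader keyword arguments following the rule of thumb above."""
    # os.cpu_count() may return None; on Linux, len(os.sched_getaffinity(0))
    # would respect CPU affinity limits instead.
    n_cores = os.cpu_count() or 1
    return {
        "batch_size": batch_size,
        "num_workers": max(1, min(n_cores, batch_size)),
        "pin_memory": False,
    }


kwargs = loader_kwargs(batch_size=32)
print(kwargs)
# Usage with a torch Dataset: DataLoader(dataset, shuffle=True, **kwargs)
```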
