Hi there, I’m running some profiling on my training code, and I realized that every 11 iterations (batches) the code freezes for a while when calling
optim.step, then resumes normally for the next 11 iterations. Does it have to do with gradient checking, or maybe a bug in my code?
The training loop is as follows:
for i, (index, im, mask) in tqdm(enumerate(train_loader), total=n_iter):
    im = im.cuda()
    mask = mask.cuda()

    # Forward propagation
    logit = self.forward(im)
    loss = self.criterion(logit, mask)

    # Compute metrics
    if i % metric_step == 0:
        pred = torch.sigmoid(logit)
        iou = eval.dice_accuracy(pred.data.cpu().numpy(),
                                 mask.data.cpu().numpy())

    # Backward propagation
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()  # <- freezes here on iterations that are multiples of 11
Are you using the number 11 for some parameter somewhere?
Based on this code I can’t see an obvious reason for the freeze.
No, there is no parameter equal to 11. The issue is not the metrics, because I’ve observed that it freezes on
self.optimizer.step(); I’ve also tried skipping the metric computation entirely.
As for batch_size, I’ve tried 32 and 50. It seems that with 32 the freeze is a little shorter. I’m away from my computer now, but if I’m not mistaken I’m using PyTorch 0.4.x. Could it be a bug, and should I try updating PyTorch?
Could it be that your code freezes once per epoch?
Are you using multiple workers in your DataLoader? If so, are you preloading the data in
__init__ or are you lazily loading it?
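To make the distinction concrete, here is a minimal sketch of the two options (load_image is a placeholder for whatever reading/decoding you do):

from torch.utils.data import Dataset

class EagerDataset(Dataset):
    # Preloads everything in __init__: slow startup, high RAM, fast __getitem__
    def __init__(self, paths):
        self.images = [load_image(p) for p in paths]  # load_image is a placeholder

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index]

class LazyDataset(Dataset):
    # Stores only the paths: each worker reads a sample from disk on demand
    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return load_image(self.paths[index])  # the disk read happens per sample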
It is freezing every 11 batches, way before completing an epoch; I haven’t paid attention to what happens at the end of an epoch.
My DataLoader does have multiple workers, 11 of them (Ooooh! Eureka moment) if I’m not mistaken. It loads each image while processing the batch, because I have a lot of data. This must be the culprit. Is there a way to asynchronously process the next batch on the CPU while the GPU is training on the current batch?
The DataLoader uses multiprocessing to load the batches asynchronously while the training takes place.
However, if you have some heavy preprocessing in your
Dataset, or the data loading is I/O-bound, you might notice small freezes when the workers can’t keep up processing the data fast enough.
You could try to increase the number of workers, store your data on an SSD (if that’s not the case already), or try to speed up the preprocessing.
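For example, the worker count is just a constructor argument; something like the following (the values are illustrative, not a recommendation):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset,   # e.g. a lazily loading Dataset
                          batch_size=32,
                          shuffle=True,
                          num_workers=16,   # illustrative; tune for your machine
                          pin_memory=True)  # pinned memory speeds up host-to-GPU copies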
I was having exactly the same problem: every couple of epochs, the progress would freeze for a few seconds.
After trying many things, I fixed the issue by wrapping my dataloader in a prefetcher and preprocessing the data (in my case, frames of videos) on the GPU. This resulted in a 6x speedup!
Before, I was preprocessing my frames inside the
__getitem__ function of the dataset, pretty much as shown in the data loading tutorial. The only preprocessing I was doing was a conversion to float (my raw data is uint8) and a resize (shrinking).
So I changed this: inside
__getitem__ remained just the code that reads the datapoint from the hard drive, and I added the DataPrefetcher class below, which is both a wrapper around the dataloader and a drop-in replacement for it.
class DataPrefetcher:
    def __init__(self, dataloader, img_shape, device):
        self.dataloader = dataloader
        self._len = len(dataloader)
        self.device = device
        self.stream = torch.cuda.Stream()  # side stream for async copies
        self.img_shape = img_shape

    def prefetch(self):
        try:
            self.next_video, self.next_label = next(self.dl_iter)
        except StopIteration:  # reached the end of the epoch
            self.next_video = self.next_label = None
            return
        with torch.cuda.stream(self.stream):
            # Copy and preprocess the next batch while the GPU trains on the current one
            self.next_label = self.next_label.to(self.device, non_blocking=True)
            self.next_video = self.next_video.to(self.device, non_blocking=True)
            self.next_video = self.next_video.float()
            self.next_video = torch.nn.functional.interpolate(
                self.next_video, size=self.img_shape)

    def __iter__(self):
        self.dl_iter = iter(self.dataloader)
        self.prefetch()  # queue up the first batch
        return self

    def __len__(self):
        return self._len

    def __next__(self):
        torch.cuda.current_stream().wait_stream(self.stream)  # sync with the copy stream
        video = self.next_video
        label = self.next_label
        if video is None or label is None:
            raise StopIteration
        self.prefetch()  # start loading the next batch before returning this one
        return video, label
(This code was adapted (read: plagiarised) from apex/amp.)
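For completeness, this is roughly how the wrapper is used; the dataset, batch size, target shape, and train_step here are made up for the example:

from torch.utils.data import DataLoader

train_loader = DataLoader(video_dataset, batch_size=8, num_workers=8)  # video_dataset is a placeholder
prefetcher = DataPrefetcher(train_loader, img_shape=(112, 112), device='cuda')

for video, label in prefetcher:
    # Batches arrive on the GPU, already converted to float and resized
    train_step(video, label)  # train_step is a placeholder for the training code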
Another advantage of this approach is that F.interpolate will resize all frames of the batch “at once”. Before, I had to loop over the frames of a video and resize them one by one (I was using scikit-image for that).
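Roughly, the batched resize looks like this (the shapes are arbitrary):

import torch
import torch.nn.functional as F

frames = torch.randint(0, 256, (16, 3, 360, 640), dtype=torch.uint8)  # fake uint8 frames
resized = F.interpolate(frames.float(), size=(180, 320))  # all 16 frames resized in one call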
Some extra remarks concerning the DataLoader. In my case, I got the best performance with:
pin_memory = False
num_workers = min(n_cores, batch_size)
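In code, that configuration would look something like this (dataset and batch_size stand in for your own):

import os
from torch.utils.data import DataLoader

batch_size = 8  # placeholder value
loader = DataLoader(dataset,  # your Dataset instance
                    batch_size=batch_size,
                    pin_memory=False,
                    num_workers=min(os.cpu_count(), batch_size))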