I was having exactly the same problem: every couple of epochs, the progress would freeze for a few seconds.
After trying many things, I was able to fix it by wrapping my dataloader in a prefetcher and preprocessing the data (in my case, video frames) on the GPU. Not only did this fix the freezes, it also gave me a 6x speedup!
Before, I was preprocessing my frames inside the `__getitem__` function of the dataset, pretty much as shown in the data loading tutorial. The only preprocessing I did was a conversion to float (my raw data is uint8) and a resize (shrinking).
So I changed this: now `__getitem__` only contains the code that reads the data point from the hard drive. Then I added a `DataPrefetcher` class, which both wraps the dataloader and acts as a drop-in replacement for it.
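For reference, here is a minimal sketch of what such a read-only dataset could look like. The `.npy` storage format and the `RawVideoDataset` name are hypothetical; swap the loading line for whatever your files actually use. Each item stays uint8 with shape (C, T, H, W), so the default collate function produces (N, C, T, H, W) batches, which is the 5-D layout that trilinear resizing in the prefetcher expects.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class RawVideoDataset(Dataset):
    """Hypothetical dataset: each video is a uint8 .npy array of shape
    (C, T, H, W). __getitem__ only reads from disk; the float conversion
    and resizing are left to the GPU-side prefetcher."""

    def __init__(self, paths, labels):
        self.paths = list(paths)
        self.labels = list(labels)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        video = np.load(self.paths[idx])   # raw uint8 frames, no preprocessing
        video = torch.from_numpy(video)    # shares memory, stays uint8
        label = torch.tensor(self.labels[idx])
        return video, label
```

The prefetcher itself: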
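```python
import torch
import torch.nn.functional as F


class DataPrefetcher:
    """Wraps a DataLoader: copies and preprocesses the next batch on a
    separate CUDA stream while the current batch is being used."""

    def __init__(self, dataloader, img_shape, device):
        self.dataloader = dataloader
        self._len = len(dataloader)
        self.device = torch.device(device)
        # torch.cuda.device(...) is a context manager and does nothing when
        # called on its own, so create the stream on the target device directly.
        self.stream = torch.cuda.Stream(device=self.device)
        self.img_shape = img_shape  # target (T, H, W) for trilinear resizing

    def prefetch(self):
        try:
            self.next_video, self.next_label = next(self.dl_iter)
        except StopIteration:
            self.next_video = None
            self.next_label = None
            return
        # Everything below is queued on the side stream, so it can overlap
        # with the work running on the current (default) stream.
        with torch.cuda.stream(self.stream):
            # non_blocking copies only overlap with compute if the source
            # tensors are in pinned memory.
            self.next_label = self.next_label.to(self.device, non_blocking=True)
            self.next_video = self.next_video.to(self.device, non_blocking=True)
            # uint8 -> float, then shrink all frames of the batch in one call.
            # mode="trilinear" expects a 5-D (N, C, T, H, W) tensor.
            self.next_video = F.interpolate(
                self.next_video.float(),
                size=self.img_shape,
                mode="trilinear",
                align_corners=False,
            )

    def __iter__(self):
        self.dl_iter = iter(self.dataloader)
        self.prefetch()
        return self

    def __len__(self):
        return self._len

    def __next__(self):
        # Make the current stream wait until the side stream has finished
        # preparing the batch.
        torch.cuda.current_stream().wait_stream(self.stream)
        video = self.next_video
        label = self.next_label
        if video is None or label is None:
            raise StopIteration
        # Mark the tensors as used on the current stream so the caching
        # allocator does not reuse their memory too early.
        video.record_stream(torch.cuda.current_stream())
        label.record_stream(torch.cuda.current_stream())
        self.prefetch()
        return video, label
```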
(This code was adapted from apex/amp.)
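Stripped of the CUDA streams, the control flow of `DataPrefetcher` is simply "always keep the next batch in reserve". The CPU-only sketch below (a hypothetical `OneAheadIterator` class, with a `prepare` callable standing in for the copy-and-resize step in `prefetch()`) shows just that iteration logic. On its own it does not speed anything up, because plain Python runs `prepare` synchronously; in `DataPrefetcher` the gain comes from queueing that work on a separate CUDA stream so it overlaps with the training step.

```python
class OneAheadIterator:
    """CPU-only illustration of the prefetch pattern: prepare item i+1
    before handing out item i. `prepare` stands in for the GPU copy and
    preprocessing that DataPrefetcher.prefetch() runs on a side stream."""

    def __init__(self, iterable, prepare):
        self.iterable = iterable
        self.prepare = prepare

    def __iter__(self):
        self._it = iter(self.iterable)  # fresh iterator, like iter(dataloader)
        self._prefetch()
        return self

    def _prefetch(self):
        try:
            self._next = self.prepare(next(self._it))
            self._done = False
        except StopIteration:
            # A flag instead of a None sentinel, so None items still work.
            self._next = None
            self._done = True

    def __next__(self):
        if self._done:
            raise StopIteration
        item = self._next
        self._prefetch()  # prepare the following item before returning this one
        return item


print(list(OneAheadIterator([1, 2, 3], prepare=lambda x: x * 10)))  # [10, 20, 30]
```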
Another advantage of this approach is that `F.interpolate` resizes all frames of the batch “at once”. Before, I had to loop over the frames of each video and resize them one by one (I was using scikit-image for that).
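To illustrate the batching point: for a 5-D (N, C, T, H, W) batch, a single trilinear `F.interpolate` call that keeps T unchanged should give the same result as resizing each frame separately with bilinear interpolation. In the sketch below, the per-frame loop uses `F.interpolate` as a stand-in for scikit-image, and the tensor sizes are arbitrary; it only demonstrates the equivalence and is not a benchmark.

```python
import torch
import torch.nn.functional as F

# Arbitrary example batch: 2 videos, 3 channels, 4 frames of 32x32, uint8,
# laid out as (N, C, T, H, W), the layout mode="trilinear" expects.
batch = torch.randint(0, 256, (2, 3, 4, 32, 32), dtype=torch.uint8)
x = batch.float()

# One call resizes every frame of every video; T stays 4, so only H and W shrink.
batched = F.interpolate(x, size=(4, 16, 16), mode="trilinear", align_corners=False)

# Per-frame loop: resize each (N, C, H, W) frame slice on its own, then restack.
frames = [
    F.interpolate(x[:, :, t], size=(16, 16), mode="bilinear", align_corners=False)
    for t in range(x.shape[2])
]
looped = torch.stack(frames, dim=2)

print(batched.shape)  # torch.Size([2, 3, 4, 16, 16])
print(torch.allclose(batched, looped, atol=1e-4))
```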
A few extra remarks about the DataLoader: in my case, I got the best performance with:
- `pin_memory=False`
- `num_workers=min(n_cores, batch_size)`
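As a concrete example, a DataLoader configured with these settings might look like the sketch below. The `TensorDataset` stand-in, the batch size of 8 and `shuffle=True` are placeholders for illustration. `os.cpu_count()` counts logical cores and can return `None`, hence the fallback. These values were the fastest on the setup described above and may need re-tuning on other hardware.

```python
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset: 64 tiny uint8 "videos" (C=3, T=4, 8x8) with integer labels.
dataset = TensorDataset(
    torch.zeros(64, 3, 4, 8, 8, dtype=torch.uint8),
    torch.arange(64),
)

batch_size = 8
n_cores = os.cpu_count() or 1  # cpu_count() may return None

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=min(n_cores, batch_size),  # capped by both core count and batch size
    pin_memory=False,
)
```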