How to prefetch data when processing with GPU?

Thank you @SimonW, I modified the code as following:

import os
import time

import torch
import torch.nn as nn

# ResNet, BasicBlock, accuracy, and train_dataset are defined elsewhere.
# CUDA_LAUNCH_BLOCKING has to be set as an environment variable before CUDA is initialized.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

for workers in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 64, 96]:
	model = ResNet(BasicBlock)
	model = torch.nn.DataParallel(model).cuda()
	criterion = nn.CrossEntropyLoss().cuda()
	optimizer = torch.optim.SGD(model.parameters(), 0.001)

	train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True,  num_workers=workers, pin_memory=False)

	model.train()
	data_times = []
	compute_times = []
	other_times = []
	backprop_times = []
	total_times = []
	torch.cuda.synchronize()
	start = time.time()
	end = start
	for i, (images, target) in enumerate(train_loader):
		target = target.cuda(non_blocking=True)  # the copy is only asynchronous if the tensor comes from pinned memory (pin_memory=True)
		torch.cuda.synchronize()
		end_data = time.time()
		data_time = end_data - end
		output = model(images)
		torch.cuda.synchronize()
		end_compute = time.time()
		compute_time = end_compute - end_data
		loss = criterion(output, target)
		acc = accuracy(output, target)
		torch.cuda.synchronize()
		end_other = time.time()
		other_time = end_other - end_compute
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		torch.cuda.synchronize()
		end_backprop = time.time()
		backprop_time = end_backprop - end_other
		total_time = data_time + compute_time + other_time + backprop_time

		if i != 0:
			data_times.append(data_time)
			compute_times.append(compute_time)
			other_times.append(other_time)
			backprop_times.append(backprop_time)
			total_times.append(total_time)

		torch.cuda.synchronize()
		end = time.time()

And here are the results:

workers         total           data/compute/other/backprop (s)
 0              5.5             1.45/0.06/0.00/0.07
 1              5.1             1.39/0.06/0.00/0.07
 2              2.6             0.64/0.07/0.00/0.08
 4              1.4             0.26/0.07/0.00/0.08
 6              1.1             0.16/0.07/0.00/0.08
 8              1.0             0.13/0.07/0.00/0.08
10              0.9             0.10/0.08/0.00/0.08
12              0.8             0.05/0.11/0.00/0.08
14              0.9             0.04/0.11/0.00/0.08
16              0.9             0.04/0.12/0.00/0.08
20              0.9             0.04/0.12/0.00/0.08
24              1.0             0.04/0.13/0.00/0.08
28              1.0             0.04/0.13/0.00/0.08
32              1.0             0.05/0.13/0.00/0.08
40              1.1             0.05/0.14/0.00/0.08
48              1.2             0.05/0.14/0.00/0.08
64              1.3             0.06/0.15/0.00/0.08
96              1.6             0.07/0.17/0.00/0.08

The “other” time is now negligible, which is expected. However, it is still not clear why the compute time increases with the number of workers while the backprop time stays roughly constant.
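For reference on the original question: a host-to-device copy can only overlap with compute when the batch comes from pinned host memory, so the usual pattern is pin_memory=True in the DataLoader plus a non_blocking=True copy issued on a side CUDA stream. A rough sketch of such a prefetcher (the CUDAPrefetcher name and structure are just illustrative, not the code used for the numbers above):

import torch

class CUDAPrefetcher:
    """Copies the next batch to the GPU on a side stream while the current batch is processed."""
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()   # side stream for the H2D copies
        self._preload()

    def _preload(self):
        try:
            self.next_images, self.next_target = next(self.loader)
        except StopIteration:
            self.next_images = self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # the copy is only asynchronous if the DataLoader was created with pin_memory=True
            self.next_images = self.next_images.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        if self.next_images is None:
            return None, None
        # make the default stream wait for the copies issued on the side stream
        torch.cuda.current_stream().wait_stream(self.stream)
        images, target = self.next_images, self.next_target
        # keep the copied tensors alive while the default stream uses them
        images.record_stream(torch.cuda.current_stream())
        target.record_stream(torch.cuda.current_stream())
        self._preload()
        return images, target

The training loop would then create the DataLoader with pin_memory=True, wrap it with prefetcher = CUDAPrefetcher(train_loader), and call images, target = prefetcher.next() until None is returned.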


There is a way to prefetch data between the CPU and GPU with cudaMemAdvise and cudaMemPrefetchAsync. I am wondering whether this has been integrated into the DataLoader. I found a prefetch_factor flag in the DataLoader constructor, but I am not sure if it is the one. If not, how can I integrate it?

Hi,

PyTorch does not use unified memory on the GPU (it has too many performance pitfalls to be useful in the deep learning case), so these two functions have no effect on PyTorch programs.
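As for the prefetch_factor flag mentioned above: it only controls how many batches each worker prepares ahead of time on the CPU side; it does not prefetch anything to the GPU. A minimal sketch, reusing the train_dataset from the benchmark code earlier in the thread:

import torch

# prefetch_factor batches are prepared in advance *per worker*, entirely on the host;
# the GPU transfer still happens in the training loop.
train_loader = torch.utils.data.DataLoader(
    train_dataset,            # same dataset object as in the benchmark above
    batch_size=256,
    shuffle=True,
    num_workers=8,
    pin_memory=True,          # page-locked host memory, enables asynchronous copies
    prefetch_factor=2,        # default value; only valid when num_workers > 0
)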

Hi Soumith, following up on this old reply: is the torchvision.transforms pipeline run by the num_workers workers in parallel or by a single thread?

I have written a custom torchvision transform for image augmentation using skimage and composed it into the loader's pipeline. Unfortunately, the augmentation is taking forever despite varying the num_workers, pin_memory, persistent_workers, and prefetch_factor values.

The transformations themselves do not determine whether multiple processes are used.
Instead, the DataLoader lets you use multiple workers by setting num_workers>0, which creates copies of the underlying Dataset and runs them in separate processes.
Each worker process then creates batches in the background by calling into Dataset.__getitem__ and applying the transformations.
If you see a performance decrease after adding your augmentation, you could profile these methods and check whether a faster implementation (from another library) would speed them up.
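As an illustration (the dataset and augmentation below are made up for the sketch, not part of torchvision), the transform runs inside __getitem__, i.e. in whichever worker process fetches that sample, so expensive augmentations are distributed across the num_workers processes:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

def my_augment(image):
    # stand-in for a custom (e.g. skimage-based) augmentation
    return np.fliplr(image).copy()

class MyImageDataset(Dataset):                     # hypothetical dataset for illustration
    def __init__(self, images, labels, transform=None):
        self.images, self.labels = images, labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            # executed inside the worker process that fetches this sample
            image = self.transform(image)
        return torch.from_numpy(image), self.labels[idx]

images = np.random.rand(100, 32, 32).astype(np.float32)
labels = np.random.randint(0, 10, size=100)
loader = DataLoader(MyImageDataset(images, labels, transform=my_augment),
                    batch_size=16, shuffle=True, num_workers=4)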
