Hi everyone,
Lately, I’ve been working on 3D data processing, and I’ve hit a significant speed bottleneck caused by the .cuda() operation in my workflow.
To give some context, I’m working with batches of volumetric data of shape (16, 8, 128, 128, 128) that need to be transferred to CUDA for the subsequent deep learning steps. I’ve observed that the .cuda() call alone takes approximately 0.5 seconds on my device, which is about as long as a full forward pass of my network (~0.5 seconds). A simplified snippet:
import torch
import time

# CUDA warm up so initialization cost is not included in the measurement
a = torch.ones([16, 8, 128, 128, 128])
a = a.cuda()
torch.cuda.synchronize()

# Time the CPU allocation plus the host-to-device copy
start = time.time()
a = torch.ones([16, 8, 128, 128, 128])
a = a.cuda()
torch.cuda.synchronize()
print(time.time() - start)
# ~0.5 s on my device
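Note that the timing above also includes allocating the CPU tensor with torch.ones; a minimal sketch that times only the transfer of an already-allocated (pageable) tensor would look like this:
import torch
import time

# Allocate the CPU tensor up front so only the transfer is measured
a = torch.ones([16, 8, 128, 128, 128])

# CUDA warm up
_ = a.cuda()
torch.cuda.synchronize()

# Time only the host-to-device copy
start = time.time()
b = a.cuda()
torch.cuda.synchronize()
print(time.time() - start)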
Despite exploring several optimizations, such as increasing the num_workers parameter, I’ve found no tangible improvement, since the bottleneck isn’t in the dataset’s __getitem__() method. Enabling pin_memory did help, but the gain was marginal, only reducing the time to approximately 0.3 seconds.
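For context, below is a minimal sketch (not my actual pipeline) of an isolated check for how much pinning alone helps; it assumes a single CUDA device and uses placeholder variable names:
import torch
import time

# Pre-allocate the source tensor in page-locked (pinned) host memory
pinned = torch.ones([16, 8, 128, 128, 128], pin_memory=True)

# Warm up
_ = pinned.to('cuda', non_blocking=True)
torch.cuda.synchronize()

# Time only the copy from pinned host memory to the GPU
start = time.time()
b = pinned.to('cuda', non_blocking=True)
torch.cuda.synchronize()
print(time.time() - start)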
In an attempt to leverage the multiprocessing in PyTorch’s DataLoader, I experimented with moving the data onto the GPU inside collate_fn() (simplified snippet below). Remarkably, this was significantly faster (~0.004 seconds) with an appropriate num_workers setting. However, it introduced unexpected anomalies: all tensors in the first batch were zero-valued, and subsequent batches contained a mix of correct and zero-valued tensors. (I have no idea why this happens.)
from functools import partial
import torch

def collate_fn(batch, device):
    # Stack the sdf_points from each sample and move them to CUDA inside the worker
    batch_sdf_points = torch.stack([item[0] for item in batch], dim=0).type(torch.float32).to(device, non_blocking=True)
    return batch_sdf_points

dataload = torch.utils.data.DataLoader(
    self.train_dataset,
    batch_size=self.batch_size,
    num_workers=self.num_workers,
    shuffle=True,
    drop_last=True,
    persistent_workers=True,
    collate_fn=partial(collate_fn, device='cuda'),
)
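For comparison, the variant I’m planning to test next keeps collate_fn on the CPU, lets pin_memory=True handle the pinning, and does the CUDA transfer in the training loop with non_blocking=True. This is only a rough sketch mirroring the snippet above (same class context; collate_fn_cpu is a hypothetical name), not something I’ve verified fixes the zero-valued batches:
import torch

def collate_fn_cpu(batch):
    # Stack sdf_points on the CPU only; pin_memory=True below copies the
    # result into page-locked memory before it reaches the main process
    return torch.stack([item[0] for item in batch], dim=0).float()

dataload = torch.utils.data.DataLoader(
    self.train_dataset,
    batch_size=self.batch_size,
    num_workers=self.num_workers,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    persistent_workers=True,
    collate_fn=collate_fn_cpu,
)

for batch_sdf_points in dataload:
    # Host-to-device copy in the main process; non_blocking=True lets it
    # overlap with GPU compute because the batch is already pinned
    batch_sdf_points = batch_sdf_points.to('cuda', non_blocking=True)
    # ... forward / backward pass ...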
Given these challenges, I’m reaching out for any insights or recommendations on how to further speed up this transfer. Any assistance or guidance would be greatly appreciated.