Seeking Optimization for Speed Bottleneck in 3D Data Processing

Hi everyone,

I’ve been working on 3D data processing and have run into a significant speed bottleneck caused by the .cuda() call in my workflow.

For context, I’m working with batches of volumetric data of shape (16, 8, 128, 128, 128) that need to be transferred to the GPU for the subsequent deep learning steps. The .cuda() call alone takes approximately 0.5 seconds on my device, which is about as long as a full forward pass of my network (~0.5 seconds). A simplified snippet:

import torch
import time

# CUDA warm-up: the first transfer also pays for one-time context initialization
a = torch.ones([16, 8, 128, 128, 128])
a = a.cuda()

torch.cuda.synchronize()
start = time.time()

# allocate a fresh host tensor and copy it to the GPU
a = torch.ones([16, 8, 128, 128, 128])
a = a.cuda()

torch.cuda.synchronize()
print(time.time() - start)
# ~0.5s on my device

I’ve explored a few optimizations. Increasing num_workers gives no tangible improvement, since the bottleneck isn’t in the dataset’s __getitem__() at all. Enabling pin_memory does help, but only marginally, bringing the transfer down to roughly 0.3 seconds (rough sketch of that setup below).
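For reference, those experiments were just flag changes on the loader; a minimal sketch with illustrative argument values, not my exact code:

loader = torch.utils.data.DataLoader(
    train_dataset,            # placeholder for my volumetric dataset
    batch_size=16,
    num_workers=4,            # raising this doesn't help: __getitem__() isn't the bottleneck
    pin_memory=True,          # page-locked host buffers
    shuffle=True,
    drop_last=True,
)

for batch in loader:
    batch = batch.cuda()      # ~0.3s on my device with pin_memory=True
    ...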

To take advantage of the multiprocessing in PyTorch’s DataLoader, I also experimented with moving the data to the GPU inside collate_fn() (simplified snippet below). Combined with an appropriate num_workers setting, this is dramatically faster (~0.004 seconds), but it introduces strange anomalies: all tensors in the first batch are zero-valued, and subsequent batches are a mix of correct and zero-valued tensors. I have no idea why this happens.

def collate_fn(batch, device):
    # Stack the per-sample sdf_points and move them to the GPU inside the worker
    batch_sdf_points = torch.stack([item[0] for item in batch], dim=0).type(torch.float32).to(device, non_blocking=True)
    return batch_sdf_points

from functools import partial

dataload = torch.utils.data.DataLoader(
    self.train_dataset,
    batch_size=self.batch_size,
    num_workers=self.num_workers,
    shuffle=True,
    drop_last=True,
    persistent_workers=True,
    collate_fn=partial(collate_fn, device='cuda'),
)

Given these issues, I’d appreciate any insights or recommendations on how to speed this transfer up further.

Hi,

You might want to use non_blocking=True in .to(device=...) or .cuda(), in conjunction with pin_memory, if you are transferring several of these large tensors in succession. If it’s only a single transfer, I don’t think it will help. It will also make profiling more difficult, since the call returns before the copy has finished.
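For example, a quick sketch (same tensor shape as in your post) of how the asynchronous copy changes what your timer sees:

import time
import torch

a = torch.ones([16, 8, 128, 128, 128]).pin_memory()

torch.cuda.synchronize()
start = time.time()
b = a.cuda(non_blocking=True)   # enqueues the copy and returns immediately
print(time.time() - start)      # tiny number: the copy is still in flight
torch.cuda.synchronize()        # wait for the copy to actually finish
print(time.time() - start)      # the real end-to-end transfer time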

About transferring to the GPU in collate_fn, it’s just a guess, but I’d look around possible synchronization issues. For example: what if the asynchronous memory transfer in the data worker is still going on while the tensor handle has already been transferred to the train process and the forward pass has started?

Hi Nicolas,

Many thanks for your help!

I’m using PyTorch Lightning, and non_blocking seems to always be enabled there, so there’s probably no further benefit I can get from that.

what if the asynchronous memory transfer in the data worker is still going on while the tensor handle has already been transferred to the train process and the forward pass has started?

Ah! That would explain the weird behaviour I encountered: under this conjecture, the collated tensors on the GPU always seem to start out zero-initialized, and the zeros I see are simply copies that haven’t landed yet.

As I’m not quite familiar with CUDA operations, is there any potential side effect of doing these CUDA operations in multiprocessing? Maybe all I need is an explicit torch.cuda.synchronize() at the end of collate_fn().
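Something like this (an untested sketch, adapting the collate_fn from my first post):

def collate_fn(batch, device):
    # Stack the per-sample sdf_points and move them to the GPU inside the worker
    batch_sdf_points = torch.stack([item[0] for item in batch], dim=0).type(torch.float32).to(device, non_blocking=True)
    # Block inside the worker until the copy has actually landed on the GPU
    torch.cuda.synchronize()
    return batch_sdf_points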

is there any potential side effect of doing these CUDA operations in multiprocessing?

I was wondering the same. To be sure, you can try doing the .cuda() on a separate CUDA stream. Something along the lines of:

s = torch.cuda.Stream(device=??)   # a stream dedicated to this worker's copies
with torch.cuda.stream(s):
    batch = batch.pin_memory().to(device=??, non_blocking=True)
s.synchronize()                    # block until the copy on this stream has finished
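(If I’m reading the semantics right, s.synchronize() blocks until the copy on that stream has completed, so the worker only hands the batch back once the data is actually on the GPU, and the copy stays off the default stream that everything else runs on.)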