Why can't the PyTorch profiler capture data-preprocessing activities in `__getitem__()`?

Profiling data-preprocessing activities in `__getitem__`

Hi all,

I'm currently using `torch.profiler` to profile the performance of my data-loading process. However, the profiler seems to aggregate the entire data-loading process under a single `enumerate(DataLoader)` operation, rather than breaking it down into the more granular `record_function` ranges that I've placed inside my `__getitem__` method.

Here’s the relevant part of my code:

import torch
from torch.utils.data import Dataset, DataLoader
from torch.profiler import profile, record_function, ProfilerActivity

class KiTSDataset(Dataset):
    """Map-style dataset yielding ``(volume, segmentation)`` tensor pairs per case.

    Each item is loaded with ``load_case`` and converted to tensors with a leading
    channel axis. It is then optionally center-cropped to ``crop_size`` and
    optionally padded/cropped along axis 1 (depth) to ``fixed_depth``.

    Every stage is wrapped in ``record_function`` so it appears as a named range
    in ``torch.profiler`` output. A profiler only records these ranges when
    ``__getitem__`` runs in the profiled process (``DataLoader(num_workers=0)``).

    Args:
        cases: Sequence of case identifiers handed to ``load_case``.
        crop_size: Optional ``(d, h, w)`` target size for the center crop.
        fixed_depth: Optional target length of axis 1.
    """

    def __init__(self, cases, crop_size=None, fixed_depth=None):
        self.cases = cases
        self.crop_size = crop_size
        self.fixed_depth = fixed_depth

    def __len__(self):
        return len(self.cases)

    def crop_center(self, volume, crop_size):
        """Center-crop the last three axes of a 4-D ``(C, D, H, W)`` tensor.

        Axes already smaller than the requested size are kept whole. Without the
        clamp, a negative start index would select a slab counted from the end
        of the axis.

        Args:
            volume: 4-D tensor to crop.
            crop_size: ``(d, h, w)`` target sizes.

        Returns:
            A view of ``volume`` with each of the last three axes at most
            ``crop_size`` long.
        """
        _, d, h, w = volume.shape
        # Clamp to 0: a negative start would be interpreted relative to the end.
        start_d = max((d - crop_size[0]) // 2, 0)
        start_h = max((h - crop_size[1]) // 2, 0)
        start_w = max((w - crop_size[2]) // 2, 0)
        return volume[
            :,
            start_d:start_d + crop_size[0],
            start_h:start_h + crop_size[1],
            start_w:start_w + crop_size[2],
        ]

    def pad_or_crop_depth(self, volume):
        """Bring axis 1 of a 4-D tensor to exactly ``self.fixed_depth``.

        Shorter inputs are padded symmetrically with zeros; any odd leftover goes
        at the end. Longer inputs are center-cropped.

        Args:
            volume: 4-D ``(C, D, H, W)`` tensor.

        Returns:
            Tensor whose axis 1 has length ``self.fixed_depth``.
        """
        depth = volume.shape[1]
        if depth < self.fixed_depth:
            pad_size = self.fixed_depth - depth
            # The pad tuple runs from the last axis backwards: (W, W, H, H, D, D).
            volume = torch.nn.functional.pad(
                volume, (0, 0, 0, 0, pad_size // 2, pad_size - pad_size // 2)
            )
        elif depth > self.fixed_depth:
            start = (depth - self.fixed_depth) // 2
            volume = volume[:, start:start + self.fixed_depth, :, :]
        return volume

    def __getitem__(self, idx):
        """Load case ``idx``, convert it to tensors and apply the configured resizing.

        Returns:
            ``(volume, segmentation)``. ``volume`` keeps the leading channel axis
            added by ``unsqueeze(0)``; ``segmentation`` has it squeezed away.
        """
        case = self.cases[idx]
        # NOTE(review): `load_case` is not defined or imported in this snippet.
        # It presumably returns two image objects exposing `.get_fdata()`
        # (e.g. nibabel images) -- confirm where it comes from.
        with record_function("loadcase"):
            volume, segmentation = load_case(case)

        with record_function("ToTensor"):
            volume = torch.tensor(volume.get_fdata(), dtype=torch.float32).unsqueeze(0)
            segmentation = torch.tensor(segmentation.get_fdata(), dtype=torch.long).unsqueeze(0)

        if self.crop_size:
            with record_function("CropCenter"):
                volume = self.crop_center(volume, self.crop_size)
                segmentation = self.crop_center(segmentation, self.crop_size)

        if self.fixed_depth:
            with record_function("PadOrCropDepth"):
                volume = self.pad_or_crop_depth(volume)
                segmentation = self.pad_or_crop_depth(segmentation)

        return volume, segmentation.squeeze(0)

def main(num_workers=0):
    """Profile one pass over a ``KiTSDataset`` DataLoader and print a CPU-time table.

    Args:
        num_workers: Forwarded to ``DataLoader``. It defaults to 0 because
            ``torch.profiler`` only records activity in the process that runs the
            ``profile`` context. With ``num_workers > 0``, ``__getitem__`` and its
            ``record_function`` ranges execute in worker processes. The profiler
            then only sees the time spent waiting inside ``enumerate(DataLoader)``.
    """
    cases = [f'case_{i:05d}' for i in range(100)]
    dataset = KiTSDataset(cases, crop_size=(100, 256, 256), fixed_depth=10)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=num_workers)

    # Only request CUDA tracing when a GPU is actually present.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for _batch in dataloader:
            with record_function("model_inference"):
                # Simulate model inference
                pass
            # Signal the end of one iteration to the profiler (once per batch).
            prof.step()

    print(prof.key_averages().table(sort_by="self_cpu_time_total"))


if __name__ == '__main__':
    main()

Here is a snapshot of the profiler's output (the screenshot itself was not preserved):

@ptrblck Hi, could you please help me out here? Alternatively, is it possible to use nsys to profile every data-preprocessing activity?

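I don't think the profiler works across multiple processes. With `num_workers > 0`, `__getitem__` (and every `record_function` range inside it) runs in separate DataLoader worker processes that the profiler in the main process does not trace. The main process therefore only records the time spent waiting for batches in `enumerate(DataLoader)`. Use `num_workers=0` when you want to profile the data-loading pipeline itself.
Here is a simple example showing this behavior: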

class MyDataset(Dataset):
    """Toy dataset of 10 random ``(3, 224, 224)`` images with integer labels.

    ``__getitem__`` wraps each transform in ``record_function``. When the
    DataLoader runs in the profiled process, these show up as named ranges.
    """

    def __init__(self):
        num_samples = 10
        self.data = torch.randn(num_samples, 3, 224, 224)
        self.target = torch.randint(0, 10, (num_samples,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample, label = self.data[index], self.target[index]

        with record_function("transform1"):
            sample = sample * 2

        with record_function("transform2"):
            label = label + 1

        return sample, label
    
dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5, num_workers=0)

# num_workers=0 keeps __getitem__ in this process, so the profiler also records
# the "transform1"/"transform2" ranges opened inside it.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, \
        record_function("loader"):
    for data, target in loader:
        continue

print(
    prof.key_averages().table(
        sort_by="cpu_time_total",
        row_limit=15,
    )
)
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                  loader        10.30%     306.577us       100.00%       2.978ms       2.978ms             1  
# enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         8.93%     266.004us        89.21%       2.657ms     885.566us             3  
#                                              transform1         6.86%     204.200us        56.38%       1.679ms     167.883us            10  
#                                               aten::mul        47.42%       1.412ms        49.52%       1.475ms     147.463us            10  
#                                             aten::stack         0.70%      20.700us        16.14%     480.542us     120.135us             4  
#                                               aten::cat        14.62%     435.316us        15.04%     447.960us     111.990us             4  
#                                              transform2         4.75%     141.413us         5.67%     168.766us      16.877us            10  
#                                                aten::to         0.30%       8.915us         2.10%      62.627us       6.263us            10  
#                                            aten::select         1.73%      51.482us         2.10%      62.556us       3.128us            20  
#                                          aten::_to_copy         0.81%      24.018us         1.80%      53.712us       5.371us            10  
#                                               aten::add         0.92%      27.353us         0.92%      27.353us       2.735us            10  
#                                             aten::copy_         0.68%      20.229us         0.68%      20.229us       2.023us            10  
#                                        aten::as_strided         0.48%      14.319us         0.48%      14.319us       0.447us            32  
#                                            aten::narrow         0.20%       5.941us         0.42%      12.644us       6.322us             2  
#                                     aten::empty_strided         0.32%       9.465us         0.32%       9.465us       0.947us            10  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 2.978ms

Thanks, that helps a lot.