Profiling data preprocessing activities in `__getitem__`
Hi all,

I'm currently using `torch.profiler` to profile the performance of my data loading process. However, I'm running into an issue where the profiler seems to aggregate the entire data loading process under a single `enumerate(dataloader)` operation, rather than breaking it down into the more granular `record_function` calls that I've placed inside my `__getitem__` method.
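To illustrate what I expected: outside the `DataLoader`, the same `record_function` pattern does produce its own labeled row in the summary table. A minimal sketch (main process only, CPU activity, with a made-up label just for illustration):

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("my_label"):  # hypothetical label, just for this sketch
        torch.randn(1000, 1000).sum()

# "my_label" shows up as its own row here, which is what I expected
# for the labels inside __getitem__ as well.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```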
Here’s the relevant part of my code:
```python
import torch
import torch.nn.functional as F  # needed for F.pad in pad_or_crop_depth
from torch.utils.data import Dataset, DataLoader
from torch.profiler import profile, record_function, ProfilerActivity

# load_case comes from the KiTS19 starter code in my setup; it returns the
# nibabel volume and segmentation images for a given case ID.
from starter_code.utils import load_case


class KiTSDataset(Dataset):
    def __init__(self, cases, crop_size=None, fixed_depth=None):
        self.cases = cases
        self.crop_size = crop_size
        self.fixed_depth = fixed_depth

    def __len__(self):
        return len(self.cases)

    def crop_center(self, volume, crop_size):
        # Center-crop a (C, D, H, W) volume to crop_size = (d, h, w).
        _, d, h, w = volume.shape
        start_d = (d - crop_size[0]) // 2
        start_h = (h - crop_size[1]) // 2
        start_w = (w - crop_size[2]) // 2
        return volume[:, start_d:start_d + crop_size[0],
                      start_h:start_h + crop_size[1],
                      start_w:start_w + crop_size[2]]

    def pad_or_crop_depth(self, volume):
        # Symmetrically pad or center-crop the depth axis to fixed_depth.
        depth = volume.shape[1]
        if depth < self.fixed_depth:
            pad_size = self.fixed_depth - depth
            volume = F.pad(volume, (0, 0, 0, 0, pad_size // 2, pad_size - pad_size // 2))
        elif depth > self.fixed_depth:
            start = (depth - self.fixed_depth) // 2
            volume = volume[:, start:start + self.fixed_depth, :, :]
        return volume

    def __getitem__(self, idx):
        case = self.cases[idx]
        with record_function("loadcase"):
            volume, segmentation = load_case(case)
        with record_function("ToTensor"):
            volume = torch.tensor(volume.get_fdata(), dtype=torch.float32).unsqueeze(0)
            segmentation = torch.tensor(segmentation.get_fdata(), dtype=torch.long).unsqueeze(0)
        if self.crop_size:
            with record_function("CropCenter"):
                volume = self.crop_center(volume, self.crop_size)
                segmentation = self.crop_center(segmentation, self.crop_size)
        if self.fixed_depth:
            with record_function("PadOrCropDepth"):
                volume = self.pad_or_crop_depth(volume)
                segmentation = self.pad_or_crop_depth(segmentation)
        return volume, segmentation.squeeze(0)


def main():
    cases = [f'case_{i:05d}' for i in range(100)]
    dataset = KiTSDataset(cases, crop_size=(100, 256, 256), fixed_depth=10)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for data in dataloader:
            with record_function("model_inference"):
                # Simulate model inference
                pass
            prof.step()

    print(prof.key_averages().table(sort_by="self_cpu_time_total"))


if __name__ == '__main__':
    main()
```
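My current guess is that this is related to the `DataLoader` workers running in separate processes (`num_workers=2`), so the `record_function` ranges inside `__getitem__` never reach the profiler in the main process. As a sanity check I could keep `__getitem__` in the main process with a one-line change (sketch below), but that wouldn't reflect my real multi-worker loading:

```python
# Sanity check: with num_workers=0, __getitem__ runs in the main process,
# so I'd expect the "loadcase"/"ToTensor"/... ranges to show up in the table.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
```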
Here is a snapshot of the profiler's output from the `num_workers=2` run above: