Performance Impact of `torch.autograd.profiler.record_function` in DataLoader's `_BaseDataLoaderIter` Iteration

:bug: Describe the bug

Wrapping the body of the `__next__` method of DataLoader's `_BaseDataLoaderIter` in `torch.autograd.profiler.record_function` introduces noticeable per-iteration overhead, especially with large datasets or small batch sizes, i.e. whenever many iterations are performed.

Current implementation:

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                # TODO(https://github.com/pytorch/pytorch/issues/76750)
                self._reset()  # type: ignore[call-arg]
            data = self._next_data()
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data
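
The overhead is easy to reproduce in isolation. As a rough sketch (the label string below is arbitrary and absolute numbers depend on the machine), merely entering and exiting the context manager a few hundred thousand times, with no profiler attached, takes measurable time:

from time import time
import torch

N = 250_000  # roughly the number of batches for 1,000,000 items at batch size 4

# Enter/exit record_function only; no profiler session is active.
t0 = time()
for _ in range(N):
    with torch.autograd.profiler.record_function("DataLoaderIter.__next__"):
        pass
t1 = time()

# Same empty loop without the context manager, as a baseline.
t2 = time()
for _ in range(N):
    pass
t3 = time()

print(f"with record_function: {t1 - t0:.3f}s | bare loop: {t3 - t2:.3f}s")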

Test Code:

import torch
from time import time
from torch.utils.data import DataLoader, Dataset

num_items = 1000000

class MyDataset(Dataset):
    def __init__(self, data: torch.Tensor) -> None:
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __getitems__(self, indices):
        # Optional batched fetch: newer DataLoader versions call this with a
        # list of indices when it is defined, instead of one __getitem__ per index.
        return tuple(self.data[indices])

    def __len__(self):
        return len(self.data)

dataset = MyDataset(torch.arange(num_items))

batch_sizes = [4, 16, 64, 256, 512]
shuffle_options = [False, True]

for shuffle in shuffle_options:
    print(" ")
    for batch_size in batch_sizes:
        dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        t0 = time()
        for _ in dl:  # iterate over the whole dataset once
            pass
        t1 = time()
        print(f"{batch_size} | {shuffle} | time: {t1 - t0}")

Results (without modification); each line is batch_size | shuffle | elapsed time in seconds:

4 | False | time: 11.467047691345215
16 | False | time: 4.080101490020752
64 | False | time: 2.109097957611084
256 | False | time: 1.6391916275024414
512 | False | time: 1.5716557502746582
 
4 | True | time: 11.690172910690308
16 | True | time: 4.2547852993011475
64 | True | time: 2.5096383094787598
256 | True | time: 1.8386955261230469
512 | True | time: 1.7621963024139404

I then removed the `with torch.autograd.profiler.record_function(self._profile_name):` context manager from the `__next__` method, making it:

    def __next__(self) -> Any:
        # with torch.autograd.profiler.record_function(self._profile_name):
        if self._sampler_iter is None:
            # TODO(https://github.com/pytorch/pytorch/issues/76750)
            self._reset()  # type: ignore[call-arg]
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                "IterableDataset replica at each worker. Please see "
                                "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data

With the context manager removed, the results become (same format):

4 | False | time: 7.440187454223633
16 | False | time: 3.06815505027771
64 | False | time: 1.857062816619873
256 | False | time: 1.5879993438720703
512 | False | time: 1.5435562133789062
 
4 | True | time: 7.606332540512085
16 | True | time: 3.2116823196411133
64 | True | time: 2.0811235904693604
256 | True | time: 1.7673640251159668
512 | True | time: 1.7190146446228027

Based on these results, wrapping the body of `__next__` in `torch.autograd.profiler.record_function(self._profile_name)` has a clear impact on iteration time: in this test, iteration takes roughly 50% longer at batch size 4 and is still a few percent slower at batch size 512. This overhead can hurt the efficiency of DataLoader operations, particularly with large datasets or frequent iterations.

I observed that `torch.autograd.profiler.record_function` is used only inside the `__next__` method of `_BaseDataLoaderIter` in torch/utils/data/dataloader.py; the rest of the DataLoader code contains no other profiler calls. I'm curious why profiling is needed specifically in this context. Could anyone explain the rationale behind this choice?
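
For reference, this annotation is what makes each batch fetch appear as a named event when a profiler session is active. A minimal sketch using the public profiler API (reusing `dataset` from the test code above) would be:

import torch
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in DataLoader(dataset, batch_size=64):
        pass

# Each __next__ call is grouped under the record_function label in this table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Whether that visibility justifies paying the context-manager cost on every iteration, even when no profiler is running, is exactly the question here.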

Additionally, I'm wondering whether it would be possible to either remove this line, or introduce a toggle to control whether the profiling annotation is applied (see the sketch below). This would give users more flexibility in managing the overhead introduced by profiling, especially in scenarios where fine-grained profiling of every iteration is not required.
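
As one possible shape for such a toggle, here is a hypothetical sketch; the flag and helper names are made up for illustration and are not an existing PyTorch API:

import contextlib
import torch

# Hypothetical module-level switch; not an existing PyTorch setting.
_ENABLE_DATALOADER_PROFILING = False

def _maybe_record_function(name: str):
    """Return the real record_function context when enabled, otherwise a no-op."""
    if _ENABLE_DATALOADER_PROFILING:
        return torch.autograd.profiler.record_function(name)
    return contextlib.nullcontext()

# __next__ could then start with:
#     with _maybe_record_function(self._profile_name):
#         ...

With a default of off, users who want per-iteration profiling could opt in, while everyone else avoids the context-manager cost entirely.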

These suggestions aim to optimize the performance and usability of PyTorch’s DataLoader module, particularly in environments where minimizing overhead is crucial for efficient data loading operations. Thank you for considering.

Versions

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.31

Python version: 3.8.18 (default, Sep 11 2023, 13:40:15)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1050-aws-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             16
On-line CPU(s) list:                0-15
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping:                           7
CPU MHz:                            2499.996
BogoMIPS:                           4999.99
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          256 KiB
L1i cache:                          256 KiB
L2 cache:                           8 MiB
L3 cache:                           35.8 MiB
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke avx512_vnni

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==1.13.0+cu117
[pip3] torchdata==0.5.0
[pip3] torcheval==0.0.7
[pip3] torchmetrics==1.2.1
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] torch                     1.13.0+cu117             pypi_0    pypi
[conda] torchdata                 0.5.0                    pypi_0    pypi
[conda] torcheval                 0.0.7                    pypi_0    pypi
[conda] torchmetrics              1.2.1                    pypi_0    pypi

Thanks for raising the issue! Could you create a GitHub issue so the code owners can take a look and chime in on how representative the use case is? I would guess the overhead is significantly smaller when "real" data loading is involved.

Thanks for your reply! I have created an issue here.

By the way, could you please explain what "real" data loading means here? The test code I pasted above is rather simple, so the speedup from removing `torch.autograd.profiler.record_function` may seem modest. However, I ran into this issue while refactoring the ItemSampler component of dgl.graphbolt, where removing `torch.autograd.profiler.record_function` sped up my code by as much as 150% or even 200%.

Your current DataLoader just indexes data that is already in memory, without any actual loading or processing, which should already show some overhead compared to "pure" indexing. Of course the DataLoader also batches, shuffles, etc. your data, so it is still a valid use case.
By "real" data loading I meant lazy loading, e.g. reading images from local storage, decoding them, and transforming them, which adds much more per-sample work than indexing a tensor.
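
For illustration, a minimal sketch of such a lazily loading dataset, assuming a folder of image files and using PIL/torchvision (which are not part of the original test), might look like:

import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class LazyImageDataset(Dataset):
    """Reads and decodes one image file from disk per __getitem__ call."""

    def __init__(self, root: str) -> None:
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index: int) -> torch.Tensor:
        # Disk read + decode + resize dominate the per-sample cost here,
        # so the record_function overhead is proportionally much smaller.
        with Image.open(self.paths[index]) as img:
            return self.transform(img.convert("RGB"))

    def __len__(self) -> int:
        return len(self.paths)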