Describe the bug
The use of torch.autograd.profiler.record_function within the __next__ method of DataLoader’s _BaseDataLoaderIter class adds noticeable overhead to every iteration, especially with large datasets or when many small batches are fetched.
Current implementation:
def __next__(self) -> Any:
    with torch.autograd.profiler.record_function(self._profile_name):
        if self._sampler_iter is None:
            # TODO(https://github.com/pytorch/pytorch/issues/76750)
            self._reset()  # type: ignore[call-arg]
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data
Test Code:
import torch
from time import time
from torch.utils.data import DataLoader, Dataset

num_items = 1000000


class MyDataset(Dataset):
    def __init__(self, data: torch.Tensor) -> None:
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __getitems__(self, indices):
        return tuple(self.data[indices])

    def __len__(self):
        return len(self.data)


dataset = MyDataset(torch.arange(num_items))
batch_sizes = [4, 16, 64, 256, 512]
shuffle_options = [False, True]  # renamed so the loop variable does not shadow the list
for shuffle in shuffle_options:
    print(" ")
    for batch_size in batch_sizes:
        # num_workers defaults to 0, so __next__ runs in the main process.
        dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        t0 = time()
        for _ in dl:
            pass
        t1 = time()
        print(f"{batch_size} | {shuffle} | time: {t1 - t0}")
Results (without modification):
4 | False | time: 11.467047691345215
16 | False | time: 4.080101490020752
64 | False | time: 2.109097957611084
256 | False | time: 1.6391916275024414
512 | False | time: 1.5716557502746582
4 | True | time: 11.690172910690308
16 | True | time: 4.2547852993011475
64 | True | time: 2.5096383094787598
256 | True | time: 1.8386955261230469
512 | True | time: 1.7621963024139404
Then I commented out the line "with torch.autograd.profiler.record_function(self._profile_name):" in the __next__ method and dedented its body, making it:
def __next__(self) -> Any:
    # with torch.autograd.profiler.record_function(self._profile_name):
    if self._sampler_iter is None:
        # TODO(https://github.com/pytorch/pytorch/issues/76750)
        self._reset()  # type: ignore[call-arg]
    data = self._next_data()
    self._num_yielded += 1
    if self._dataset_kind == _DatasetKind.Iterable and \
            self._IterableDataset_len_called is not None and \
            self._num_yielded > self._IterableDataset_len_called:
        warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                    "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                          self._num_yielded)
        if self._num_workers > 0:
            warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                         "IterableDataset replica at each worker. Please see "
                         "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
        warnings.warn(warn_msg)
    return data
With this change, the results are:
4 | False | time: 7.440187454223633
16 | False | time: 3.06815505027771
64 | False | time: 1.857062816619873
256 | False | time: 1.5879993438720703
512 | False | time: 1.5435562133789062
4 | True | time: 7.606332540512085
16 | True | time: 3.2116823196411133
64 | True | time: 2.0811235904693604
256 | True | time: 1.7673640251159668
512 | True | time: 1.7190146446228027
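Each configuration above was timed once with time(), so the smaller differences are easily swamped by noise: at batch size 512 the two runs differ by only about 28 ms (1.57 s vs. 1.54 s). A minimal sketch of a repeated-measurement harness follows, assuming it is acceptable to rerun each configuration several times; the function names are hypothetical and not part of PyTorch.

```python
import statistics
import time
from typing import Callable, Iterable, List


def time_full_passes(make_iterable: Callable[[], Iterable], repeats: int = 5) -> List[float]:
    """Return the wall-clock duration (seconds) of `repeats` full passes.

    A new iterable is built for every pass, outside the timed region, so
    construction cost (e.g. building a DataLoader object) is not counted.
    """
    durations = []
    for _ in range(repeats):
        iterable = make_iterable()
        start = time.perf_counter()  # monotonic, high-resolution clock
        for _item in iterable:
            pass
        durations.append(time.perf_counter() - start)
    return durations


def summarize(durations: List[float]) -> str:
    # The minimum is the run least disturbed by other system activity;
    # the median shows the typical cost.
    return f"min: {min(durations):.4f} s | median: {statistics.median(durations):.4f} s"


if __name__ == "__main__":
    # Stand-in workload; in the test script above this would be
    # lambda: DataLoader(dataset, batch_size=batch_size, shuffle=shuffle).
    print(summarize(time_full_passes(lambda: range(100_000))))
```

Comparing the minimum over several passes with and without record_function should make the difference at large batch sizes easier to separate from run-to-run variation.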
Based on these results, wrapping the body of _BaseDataLoaderIter.__next__ in torch.autograd.profiler.record_function(self._profile_name) noticeably slows down iteration, even though no profiler is running in the test. The slowdown is largest at small batch sizes, where __next__ is called most often (about 11.5 s vs. 7.4 s at batch size 4), and nearly disappears at large ones (about 1.57 s vs. 1.54 s at batch size 512), which points to a fixed cost paid on every call to __next__. This overhead can therefore matter for DataLoader workloads that fetch many small batches from large datasets.
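To check whether the extra time behaves like a fixed cost per __next__ call, the difference between the two listings can be divided by the number of batches, ceil(1,000,000 / batch_size), since drop_last defaults to False. The sketch below only reuses the numbers reported above and takes no new measurements. For the unshuffled runs at batch sizes 4, 16 and 64 it gives roughly 16 µs per batch; the larger batch sizes have too few batches and too small a difference for a precise estimate.

```python
import math
from typing import Tuple

NUM_ITEMS = 1_000_000  # dataset length in the test script
Key = Tuple[int, bool]  # (batch_size, shuffle)

# Seconds per full pass, copied from the "without modification" listing
# (record_function still present in __next__).
WITH_RECORD_FUNCTION = {
    (4, False): 11.467047691345215, (16, False): 4.080101490020752,
    (64, False): 2.109097957611084, (256, False): 1.6391916275024414,
    (512, False): 1.5716557502746582, (4, True): 11.690172910690308,
    (16, True): 4.2547852993011475, (64, True): 2.5096383094787598,
    (256, True): 1.8386955261230469, (512, True): 1.7621963024139404,
}

# Same keys, copied from the listing after record_function was removed.
WITHOUT_RECORD_FUNCTION = {
    (4, False): 7.440187454223633, (16, False): 3.06815505027771,
    (64, False): 1.857062816619873, (256, False): 1.5879993438720703,
    (512, False): 1.5435562133789062, (4, True): 7.606332540512085,
    (16, True): 3.2116823196411133, (64, True): 2.0811235904693604,
    (256, True): 1.7673640251159668, (512, True): 1.7190146446228027,
}


def num_batches(batch_size: int) -> int:
    # drop_last defaults to False, so a final partial batch is still yielded.
    return math.ceil(NUM_ITEMS / batch_size)


def per_batch_overhead_us(key: Key) -> float:
    """Extra time per __next__ call, in microseconds."""
    delta = WITH_RECORD_FUNCTION[key] - WITHOUT_RECORD_FUNCTION[key]
    return delta / num_batches(key[0]) * 1e6


def slowdown_pct(key: Key) -> float:
    """How much slower the run with record_function was, in percent."""
    return (WITH_RECORD_FUNCTION[key] / WITHOUT_RECORD_FUNCTION[key] - 1) * 100


for key in sorted(WITH_RECORD_FUNCTION):
    batch_size, shuffle = key
    print(f"{batch_size:>3} | {str(shuffle):>5} | {num_batches(batch_size):>6} batches | "
          f"{per_batch_overhead_us(key):5.1f} us/batch | {slowdown_pct(key):5.1f}% slower")
```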
I noticed that torch.autograd.profiler.record_function is used only in the __next__ method of _BaseDataLoaderIter in torch/utils/data/dataloader.py; the rest of the DataLoader code makes no other calls to the profiler. Why is profiling needed specifically here? Could anyone explain the rationale behind this choice?
Additionally, I’m wondering if there’s a possibility to either remove this line of code, or introduce a toggle switch to control whether profiling is enabled. This would provide users with more flexibility in managing the overhead introduced by profiling, especially in scenarios where fine-grained profiling is not required for every iteration.
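One way the suggested switch could look is to skip the context manager entirely when annotations are turned off. The sketch below is a hypothetical, framework-independent wrapper: AnnotatedIter and its enabled flag are not part of PyTorch, and annotate stands for any callable that takes a name and returns a context manager (torch.autograd.profiler.record_function has that shape). RecordingAnnotation is a stand-in used only to show when the annotation is entered.

```python
from typing import Callable, ContextManager, Generic, Iterator, List, TypeVar

T = TypeVar("T")


class AnnotatedIter(Generic[T]):
    """Hypothetical iterator wrapper with an on/off switch for per-call annotation.

    `annotate` is any callable mapping a name to a context manager, for example
    torch.autograd.profiler.record_function. Neither this class nor the
    `enabled` flag exists in PyTorch; they only illustrate the proposed toggle.
    """

    def __init__(self, inner: Iterator[T], name: str,
                 annotate: Callable[[str], ContextManager], enabled: bool = True) -> None:
        self._inner = inner
        self._name = name
        self._annotate = annotate
        self._enabled = enabled

    def __iter__(self) -> "AnnotatedIter[T]":
        return self

    def __next__(self) -> T:
        if not self._enabled:
            # Disabled: no context manager is created or entered.
            return next(self._inner)
        # Enabled: every call, including the final one that raises
        # StopIteration, runs inside the annotation.
        with self._annotate(self._name):
            return next(self._inner)


class RecordingAnnotation:
    """Stand-in context manager that logs each name it is entered with."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def __enter__(self) -> "RecordingAnnotation":
        self.log.append(self.name)
        return self

    def __exit__(self, *exc_info: object) -> bool:
        return False  # never suppress exceptions (including StopIteration)


if __name__ == "__main__":
    log: List[str] = []

    def make(name: str) -> RecordingAnnotation:
        return RecordingAnnotation(name, log)

    annotated = AnnotatedIter(iter(range(3)), "demo.__next__", make, enabled=True)
    plain = AnnotatedIter(iter(range(3)), "demo.__next__", make, enabled=False)
    print(list(annotated), list(plain), len(log))
```

In the DataLoader, annotate would correspond to record_function and name to self._profile_name; how the flag would be exposed (a constructor argument, a global setting, or an environment variable) is left open here. Checking a boolean on each call should be much cheaper than entering a context manager, which is the point of this design.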
These suggestions aim to improve the performance and usability of PyTorch’s DataLoader, particularly in environments where minimizing data-loading overhead is crucial. Thank you for considering them.
Versions
PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.31
Python version: 3.8.18 (default, Sep 11 2023, 13:40:15) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1050-aws-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping: 7
CPU MHz: 2499.996
BogoMIPS: 4999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB
L1i cache: 256 KiB
L2 cache: 8 MiB
L3 cache: 35.8 MiB
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke avx512_vnni
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==1.13.0+cu117
[pip3] torchdata==0.5.0
[pip3] torcheval==0.0.7
[pip3] torchmetrics==1.2.1
[conda] numpy 1.24.4 pypi_0 pypi
[conda] torch 1.13.0+cu117 pypi_0 pypi
[conda] torchdata 0.5.0 pypi_0 pypi
[conda] torcheval 0.0.7 pypi_0 pypi
[conda] torchmetrics 1.2.1 pypi_0 pypi