Hello, I ran into a problem with data loading when constructing batches that contain custom object instances. I can see that something is happening under the hood to the object instances after collate_fn
is executed. Can someone explain what is going on and how to overcome it? I would like to build a batch that contains custom object instances as well as tensors.
from functools import partial
from time import perf_counter

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# Seed torch's global RNG so the randomly generated payloads are reproducible
# across runs of this benchmark.
torch.random.manual_seed(0)
class SomeObject:
    """A deliberately heavy payload (~400 KB numpy array) attached to each sample.

    Its only purpose is to make per-sample data expensive to serialize, so the
    cost of shipping batches out of DataLoader worker processes becomes visible.
    """

    def __init__(self):
        # Draw from torch's RNG stream, then convert to numpy for storage.
        self.data_torch = torch.rand(100000).numpy()
class Dataset:
    """Toy map-style dataset pairing tiny (feature, target) scalars with one
    heavy custom object per sample.

    Args:
        size: number of samples.
        numpy: if True, convert the (size, 2) feature/target matrix from a
            torch tensor to a numpy array after construction.
    """

    def __init__(self, size: int, numpy: bool):
        self.data = torch.rand(size, 2)
        # One heavy (~400 KB) object per sample; these dominate the cost of
        # transferring batches between DataLoader workers and the main process.
        self.object = [SomeObject() for _ in range(size)]
        if numpy:
            self.data = self.data.numpy()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i):
        """Return sample *i* as a dict; a fresh dict is built on every call."""
        d = self.data[i]
        return {
            "feature": d[0],
            "target": d[1],
            "object": self.object[i],
        }

    @staticmethod
    def collate_fn(batch, remove_object):
        """Identity collate: return the list of sample dicts without stacking.

        When ``remove_object`` is truthy, the heavy "object" entry is popped
        from each sample dict in place before the batch is returned.
        """
        if remove_object:
            # Plain loop instead of a list comprehension used only for its
            # side effects (the comprehension built a throwaway list of Nones).
            for sample in batch:
                sample.pop("object", None)
        return batch
SIZE = 1024 * 10
BATCH_SIZE = 2048

print(f"\ntorch version: {torch.__version__}")

# Benchmark every combination: stripping the heavy objects vs. keeping them,
# and unpinned vs. pinned memory, sweeping the number of worker processes.
for remove_object in [True, False]:
    for pin_memory in [False, True]:
        for numpy in [True]:
            dataset = Dataset(size=SIZE, numpy=numpy)
            times = []
            print("\n" + "-" * 20 + f"\n{numpy=} {pin_memory=} {remove_object=}\n")
            for workers in range(0, 13, 2):
                loader = DataLoader(
                    dataset=dataset,
                    batch_size=BATCH_SIZE,
                    num_workers=workers,
                    # functools.partial of a module-level callable instead of
                    # a lambda: the collate_fn is pickled into each worker, and
                    # a lambda raises a PicklingError under the "spawn" start
                    # method (the default on Windows/macOS). partial also
                    # binds remove_object eagerly, sidestepping the classic
                    # late-binding-closure-in-a-loop pitfall.
                    collate_fn=partial(Dataset.collate_fn, remove_object=remove_object),
                    pin_memory=pin_memory,
                )
                # Time one full pass over the dataset.
                t = perf_counter()
                for _ in loader:
                    continue
                t = perf_counter() - t
                times.append((workers, t))
                print(f"workers={workers}: {t:.2f}s")
Results:
torch version: 2.2.2
--------------------
numpy=True pin_memory=False remove_object=True
workers=0: 0.01s
workers=2: 0.27s
workers=4: 0.39s
workers=6: 0.50s
workers=8: 0.65s
workers=10: 0.76s
workers=12: 0.89s
--------------------
numpy=True pin_memory=True remove_object=True
workers=0: 0.79s
workers=2: 0.50s
workers=4: 0.72s
workers=6: 0.93s
workers=8: 1.12s
workers=10: 1.34s
workers=12: 1.56s
--------------------
numpy=True pin_memory=False remove_object=False
workers=0: 0.01s
workers=2: 6.97s
workers=4: 7.26s
workers=6: 7.33s
workers=8: 7.58s
workers=10: 7.73s
workers=12: 7.93s
--------------------
numpy=True pin_memory=True remove_object=False
workers=0: 0.04s
workers=2: 5.87s
workers=4: 6.37s
workers=6: 6.82s
workers=8: 7.47s
workers=10: 8.07s
workers=12: 8.52s
As you can see, with remove_object=True in collate_fn we get much faster loading than with remove_object=False. Also, the problem does not appear when num_workers=0.
What is going on between collate_fn and next(iter(loader)) that causes this performance hit?