Data loading is very slow when using collate_fn with custom object instances

Hello, I ran into a problem with data loading when constructing batches that contain custom object instances. I can see that something happens under the hood to the object instances after collate_fn is executed. Can someone explain what is going on and how to overcome it? I would like a batch to contain custom object instances as well as tensors.

from time import perf_counter

import torch
from torch.utils.data import DataLoader

torch.random.manual_seed(0)


class SomeObject:
    def __init__(self):
        # each instance carries ~400 kB of numpy data (100000 float32 values)
        self.data_torch = torch.rand(100000).numpy()


class Dataset:
    def __init__(self, size: int, numpy: bool):
        self.data = torch.rand(size, 2)
        self.object = [SomeObject() for _ in range(size)]
        if numpy:
            self.data = self.data.numpy()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        d = self.data[i]
        return {
            "feature": d[0],
            "target": d[1],
            "object": self.object[i],
        }

    @staticmethod
    def collate_fn(batch, remove_object):
        if remove_object:
            for x in batch:
                x.pop("object", None)
        return batch


SIZE = 1024 * 10
BATCH_SIZE = 2048

print(f"\ntorch version: {torch.__version__}")

for remove_object in [True, False]:
    for pin_memory in [False, True]:
        for numpy in [True]:
            dataset = Dataset(size=SIZE, numpy=numpy)

            times = []

            print("\n" + "-" * 20 + f"\n{numpy=} {pin_memory=} {remove_object=}\n")

            for workers in range(0, 13, 2):
                loader = DataLoader(
                    dataset=dataset,
                    batch_size=BATCH_SIZE,
                    num_workers=workers,
                    # note: a lambda collate_fn is only picklable with the
                    # fork start method; functools.partial would be portable
                    collate_fn=lambda batch: Dataset.collate_fn(batch, remove_object),
                    pin_memory=pin_memory,
                )

                t = perf_counter()
                for x in loader:
                    continue
                t = perf_counter() - t
                times.append((workers, t))
                print(f"workers={workers}: {t:.2f}s")

Results:

torch version: 2.2.2

--------------------
numpy=True pin_memory=False remove_object=True

workers=0: 0.01s
workers=2: 0.27s
workers=4: 0.39s
workers=6: 0.50s
workers=8: 0.65s
workers=10: 0.76s
workers=12: 0.89s

--------------------
numpy=True pin_memory=True remove_object=True

workers=0: 0.79s
workers=2: 0.50s
workers=4: 0.72s
workers=6: 0.93s
workers=8: 1.12s
workers=10: 1.34s
workers=12: 1.56s

--------------------
numpy=True pin_memory=False remove_object=False

workers=0: 0.01s
workers=2: 6.97s
workers=4: 7.26s
workers=6: 7.33s
workers=8: 7.58s
workers=10: 7.73s
workers=12: 7.93s

--------------------
numpy=True pin_memory=True remove_object=False

workers=0: 0.04s
workers=2: 5.87s
workers=4: 6.37s
workers=6: 6.82s
workers=8: 7.47s
workers=10: 8.07s
workers=12: 8.52s

As you can see, with remove_object=True in collate_fn, loading is much faster than with remove_object=False. The problem also does not appear when num_workers=0.

What is going on between collate_fn and next(iter(loader)) that causes this performance hit?

I got a reply in pytorch/pytorch issue #123439 (Slow DataLoader in new version when num_workers>0 / objects in collate_fn slow down batching): https://github.com/pytorch/pytorch/issues/123439

But if someone has another solution or a more detailed explanation, I would be glad to hear it.