Torch’s _worker_loop defaults num_threads to 1 in each worker, but in my case I do a lot of copying of data off disk via memory-mapped tensors, so it’s useful to have num_threads > 1. However, if I try to force the number of threads above 1 with a worker_init_fn:
def worker_init(worker_id):
    torch.set_num_threads(96 // 4)
my DataLoader never returns a batch (it just hangs). Is there any way around this? More generally, is there a way to configure the number of threads used by the worker processes of a torch DataLoader? I’ve included code to repro below, plus a sketch at the very end of the effect I’m after.
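(For context, the per-worker default is easy to see with a worker_init_fn that only reads the thread count instead of setting it; the helper below is just for illustration and is not part of the repro:)

def report_threads(worker_id):
    # Read-only check: each worker should report 1 here, since _worker_loop
    # calls torch.set_num_threads(1) before worker_init_fn runs.
    print(f"worker {worker_id}: torch.get_num_threads() = {torch.get_num_threads()}")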
import torch
import time
from tqdm import tqdm
from tensordict import TensorDict
from torch.utils.data import Dataset, DataLoader, BatchSampler, SequentialSampler
write_new = True
if write_new:
    # One-off: build a large TensorDict and memory-map it to disk.
    B = 512_000
    D = 128
    flat_fields = {
        f"feat_{i}": torch.randn(B, D) for i in range(20)
    }
    flat_fields.update({
        f"cat_{i}": torch.randint(0, 100, (B,)) for i in range(5)
    })
    flat_fields.update({
        f"mask_{i}": torch.randint(0, 2, (B,), dtype=torch.bool) for i in range(5)
    })
    nested1 = TensorDict({
        "subfeat_1": torch.randn(B, D),
        "subfeat_2": torch.randn(B, D)
    }, batch_size=[B])
    nested2 = TensorDict({
        "subcat_1": torch.randint(0, 10, (B,)),
        "submask_1": torch.randint(0, 2, (B,), dtype=torch.bool)
    }, batch_size=[B])
    td = TensorDict(flat_fields, batch_size=[B])
    td["nested1"] = nested1
    td["nested2"] = nested2
    td.memmap_("/tmp/mmap_tensordict_large", num_threads=8)

td_mmap = TensorDict.load_memmap("/tmp/mmap_tensordict_large")
class MemmapTensorDictDataset(Dataset):
    def __init__(self, td):
        # torch.set_num_threads(1)
        self.td = td

    def __getitem__(self, idx):
        return self.td[idx]

    def __len__(self):
        return self.td.batch_size[0]
batch_size = 8_000
dataset = MemmapTensorDictDataset(td_mmap)
inner_sampler = SequentialSampler(range(len(dataset)))
sampler = BatchSampler(inner_sampler, batch_size=batch_size, drop_last=False)
def benchmark(loader, name, epochs=1):
    t0 = time.time()
    for _ in tqdm(range(epochs)):
        for _ in loader:
            break
    elapsed = time.time() - t0
    print(f"{name:15s}: {elapsed:.2f}s")
batch_size = 4096
loader0 = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
    sampler=sampler,
    collate_fn=lambda x: x[0],
)

def worker_init(worker_id):
    torch.set_num_threads(96 // 4)

loader4 = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
    pin_memory=False,
    sampler=sampler,
    collate_fn=lambda x: x[0],
    prefetch_factor=2,
    persistent_workers=True,
    worker_init_fn=worker_init,
)
benchmark(loader0, "num_workers=0", epochs=10)
benchmark(loader4, "num_workers=4", epochs=10)  # hangs when worker_init is used; without it, it runs but is much slower, in part because num_threads=1 in each worker
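For reference, the effect I’m after is roughly the sketch below: raise the thread count from inside each worker on first access, instead of via worker_init_fn. The class name is mine, and I haven’t verified that this avoids the hang; it’s only meant to make the intent concrete.

from torch.utils.data import get_worker_info

class LazyThreadsDataset(MemmapTensorDictDataset):
    # Hypothetical variant: bump the per-worker thread count lazily on the first
    # __getitem__ inside a worker process, rather than in worker_init_fn.
    def __getitem__(self, idx):
        if get_worker_info() is not None and torch.get_num_threads() == 1:
            torch.set_num_threads(96 // 4)  # same per-worker budget as worker_init above
        return super().__getitem__(idx)

Swapping dataset for LazyThreadsDataset(td_mmap) in loader4 (and dropping worker_init_fn) is the comparison I have in mind, if something along these lines is even a reasonable way to do it.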