Unexpected DataLoader hang when calling torch.set_num_threads in worker_init_fn

Torch's _worker_loop defaults each worker to num_threads=1, but in my case I do a lot of copying data off disk via memory-mapped tensors, so it's useful to have num_threads > 1. If I try to force the number of threads to be greater than 1 in the worker init function:

def worker_init(worker_id):
    torch.set_num_threads(96 // 4) 

the DataLoader never returns a batch. Is there any way around this? Is there a supported way to configure the number of threads used by the worker processes of a torch DataLoader? I've included code to repro below, after a quick sanity check on the default.
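For context, a tiny standalone sketch (the toy dataset and names are arbitrary, not part of the repro) that just prints the intra-op thread count from inside a worker; it should report 1, consistent with the default that _worker_loop applies before worker_init_fn runs:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(1)

def report_threads(worker_id):
    # no set_num_threads call here, just observe the worker's default
    print(f"worker {worker_id}: torch.get_num_threads() = {torch.get_num_threads()}")

if __name__ == "__main__":
    loader = DataLoader(ToyDataset(), num_workers=2, worker_init_fn=report_threads)
    for _ in loader:
        pass  # each worker should print 1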

import torch
import time
from tqdm import tqdm
from tensordict import TensorDict
from torch.utils.data import Dataset, DataLoader, BatchSampler, SequentialSampler

write_new = True
if write_new:
    B = 512_000
    D = 128 
    flat_fields = {
        f"feat_{i}": torch.randn(B, D) for i in range(20)
    }
    flat_fields.update({
        f"cat_{i}": torch.randint(0, 100, (B,)) for i in range(5)
    })
    flat_fields.update({
        f"mask_{i}": torch.randint(0, 2, (B,), dtype=torch.bool) for i in range(5)
    })

    nested1 = TensorDict({
        "subfeat_1": torch.randn(B, D),
        "subfeat_2": torch.randn(B, D)
    }, batch_size=[B])
    
    nested2 = TensorDict({
        "subcat_1": torch.randint(0, 10, (B,)),
        "submask_1": torch.randint(0, 2, (B,), dtype=torch.bool)
    }, batch_size=[B])
   
    td = TensorDict(flat_fields, batch_size=[B])
    td["nested1"] = nested1
    td["nested2"] = nested2
    
    td.memmap_("/tmp/mmap_tensordict_large", num_threads=8)

td_mmap = TensorDict.load_memmap("/tmp/mmap_tensordict_large")

class MemmapTensorDictDataset(Dataset):
    def __init__(self, td):
        # torch.set_num_threads(1)
        self.td = td

    def __getitem__(self, idx):
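        # idx is a whole list of indices when a BatchSampler is passed as the
        # DataLoader sampler; TensorDict fancy-indexes all of them in one go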
        return self.td[idx]

    def __len__(self):
        return self.td.batch_size[0]

batch_size = 8_000
dataset = MemmapTensorDictDataset(td_mmap)
inner_sampler = SequentialSampler(range(len(dataset)))
sampler = BatchSampler(inner_sampler, batch_size=batch_size, drop_last=False)
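
Quick sanity check on what the sampler yields (a plain list of indices per batch; with 512_000 rows and a batch size of 8_000 that is 64 batches):

print(len(sampler), len(next(iter(sampler))))  # 64 batches, 8000 indices each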

def benchmark(loader, name, epochs=1):
    t0 = time.time()
    for _ in tqdm(range(epochs)):
        for _ in loader:
            break  # only time fetching the first batch of each pass
    elapsed = time.time() - t0
    print(f"{name:15s}: {elapsed:.2f}s")

loader0 = DataLoader(
    dataset,
    num_workers=0,
    pin_memory=True,
    sampler=sampler,
    # the BatchSampler above already yields lists of 8_000 indices, so each
    # __getitem__ returns one pre-batched TensorDict and collate_fn just unwraps it
    collate_fn=lambda x: x[0],
)


def worker_init(worker_id):
    # split the machine's 96 cores across 4 workers -> 24 intra-op threads each
    torch.set_num_threads(96 // 4)

loader4 = DataLoader(
    dataset,
    num_workers=4,
    pin_memory=False,
    sampler=sampler,
    collate_fn=lambda x: x[0],
    prefetch_factor=2,
    persistent_workers=True,
    worker_init_fn=worker_init,
)

benchmark(loader0, "num_workers=0", epochs=10)
benchmark(loader4, "num_workers=4", epochs=10)  # THIS WILL HANG when worker_init is used; without it, it runs but is much slower, in part because each worker is stuck at num_threads=1
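
One direction I have not verified yet (it is only a guess on my part that the hang is related to thread state inherited when the workers are forked) is switching the worker start method to spawn. A sketch, assuming the memmapped TensorDict and the callables survive pickling under spawn; it would also need the usual if __name__ == "__main__" guard around the script body and a module-level collate function instead of the lambda:

def unwrap_first(batch):
    # module-level replacement for the lambda collate_fn, since spawn pickles callables
    return batch[0]

loader4_spawn = DataLoader(
    dataset,
    num_workers=4,
    sampler=sampler,
    collate_fn=unwrap_first,
    prefetch_factor=2,
    persistent_workers=True,
    worker_init_fn=worker_init,
    multiprocessing_context="spawn",  # guess: freshly spawned workers may avoid whatever deadlocks under fork
)

benchmark(loader4_spawn, "spawn workers", epochs=10)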