The context here is a scientific machine learning problem in which the training data (which happen to be sequences) are simulated from an underlying ground-truth scientific model. A key wrinkle is that the model needs to see simulations of different sizes during training. So a minibatch of data has shape [batch_size, seq_length, seq_depth], where seq_depth is always the same for a given problem, but seq_length needs to change from one minibatch to the next.
Because simulating the data is itself nontrivial, I want to use a DataLoader with multiple workers during training. The issue I’m running into is that the workers don’t “see” a change I make to a Dataset attribute that varies the sequence length during training.
Here’s a minimal reproducible example where seq_depth=2 to illustrate the issue.
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

def my_gen_sample(n):
    return torch.randn(n, 2)

class TestDataset(Dataset):
    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        self.seq_length = 1  # attribute to be modified after each minibatch

    def __len__(self):
        return self.n_sims

    def set_seq_length(self, n: int):
        self.seq_length = n

    def __getitem__(self, idx):
        x = self.gen_sample(self.seq_length)
        return x
This works as expected when num_workers=0. But when num_workers > 0, the changes to ds.seq_length after each minibatch are no longer reflected.
Code for num_workers=0:
ds = TestDataset(50*64, my_gen_sample)
dl = DataLoader(ds, batch_size=64, num_workers=0)
ds.set_seq_length(33)
lengths = []  # one entry per minibatch
for i, x_b in enumerate(dl):
    # draw a plain Python int so it can be passed to torch.randn
    seq_length = int(np.random.randint(low=5, high=50))
    ds.set_seq_length(seq_length)
    lengths.append(x_b.shape[1])  # seq_length of this minibatch
Now compare with num_workers=2:
dl = DataLoader(ds, batch_size=64, num_workers=2)
ds.set_seq_length(33)
lengths2 = []  # one entry per minibatch
for i, x_b in enumerate(dl):
    # draw a plain Python int so it can be passed to torch.randn
    seq_length = int(np.random.randint(low=5, high=50))
    ds.set_seq_length(seq_length)
    lengths2.append(x_b.shape[1])  # seq_length of this minibatch
plt.plot(lengths)
plt.plot(lengths2)
plt.show()
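One possible direction, sketched here under assumptions rather than offered as the recommended approach: since each worker process holds its own copy of the Dataset, the sequence length can travel with the index instead of living in a Dataset attribute. In the sketch below, a custom batch sampler running in the main process yields lists of (index, seq_length) tuples, and `__getitem__` unpacks them. The names `LengthAwareDataset`, `RandomLengthBatchSampler`, and `collect_lengths` are hypothetical and introduced only for illustration.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


def my_gen_sample(n):
    return torch.randn(n, 2)


class LengthAwareDataset(Dataset):
    """Hypothetical variant of TestDataset: the length arrives with the index."""

    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample

    def __len__(self):
        return self.n_sims

    def __getitem__(self, key):
        _idx, seq_length = key  # key is an (index, seq_length) tuple
        return self.gen_sample(seq_length)


class RandomLengthBatchSampler(Sampler):
    """Hypothetical batch sampler: one random seq_length shared by each batch."""

    def __init__(self, n_sims, batch_size, low=5, high=50, seed=0):
        self.n_sims = n_sims
        self.batch_size = batch_size
        self.low, self.high = low, high
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        # number of batches, counting a final partial batch
        return (self.n_sims + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for start in range(0, self.n_sims, self.batch_size):
            # drawn in the main process, so no worker-side state is needed
            n = int(self.rng.integers(self.low, self.high))  # high is exclusive
            stop = min(start + self.batch_size, self.n_sims)
            yield [(i, n) for i in range(start, stop)]


def collect_lengths(num_workers, n_sims=50 * 64, batch_size=64):
    ds = LengthAwareDataset(n_sims, my_gen_sample)
    sampler = RandomLengthBatchSampler(n_sims, batch_size)
    dl = DataLoader(ds, batch_sampler=sampler, num_workers=num_workers)
    # default collation stacks samples to [batch_size, seq_length, 2]
    return [x_b.shape[1] for x_b in dl]


if __name__ == "__main__":
    lengths = collect_lengths(num_workers=2)
    print(len(lengths), min(lengths), max(lengths))
```

Because the lengths are drawn by the sampler in the main process, the sequence of lengths does not depend on the number of workers. This sketch does not address how random number generators inside the workers are seeded, which may matter for the simulation itself.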
I would be grateful for any suggestions on how to get the num_workers=0 behavior with multiple workers. Thank you!