The context here is a scientific machine learning problem in which training data (which happens to be sequences) is simulated from an underlying ground-truth scientific model. A key wrinkle is that it’s important that the model sees simulations of different sizes during training. So a minibatch of data has shape [batch_size, seq_length, seq_depth], where seq_depth is always the same for a given problem, but seq_length needs to change from one minibatch to the next.
Because simulating the data is itself nontrivial, I want to use a DataLoader with multiple workers during training. The problem is that the workers don’t “see” changes I make during training to a Dataset attribute that controls the sequence length.
Here’s a minimal reproducible example where seq_depth=2 to illustrate the issue.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset


def my_gen_sample(n):
    return torch.randn(n, 2)


class TestDataset(Dataset):
    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        self.seq_length = 1  # attribute to be modified after each minibatch

    def __len__(self):
        return self.n_sims

    def set_seq_length(self, n: int):
        self.seq_length = n

    def __getitem__(self, idx):
        x = self.gen_sample(self.seq_length)
        return x
```
This works as expected when num_workers=0. But when num_workers > 0, the changes to ds.seq_length made after each minibatch are no longer reflected in the batches that follow.
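A likely explanation, sketched here under the assumption that the standard multi-process DataLoader is in use: when iteration starts with num_workers > 0, each worker process gets its own copy of the Dataset (forked from the main process, or pickled and sent to it under the spawn start method). Later calls to set_seq_length therefore only change the main process's copy. The snippet below mimics that snapshot with `copy.deepcopy`, without starting any workers; `worker_copy` is a hypothetical stand-in for a worker's dataset. Note also that workers prefetch batches ahead of the training loop (`prefetch_factor`, by default 2 batches per worker), so even shared state would only take effect after a lag.

```python
import copy

import torch
from torch.utils.data import Dataset


def my_gen_sample(n):
    return torch.randn(n, 2)


class TestDataset(Dataset):
    # Same dataset as in the question
    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        self.seq_length = 1

    def __len__(self):
        return self.n_sims

    def set_seq_length(self, n: int):
        self.seq_length = n

    def __getitem__(self, idx):
        return self.gen_sample(self.seq_length)


ds = TestDataset(50 * 64, my_gen_sample)
ds.set_seq_length(33)

# Stand-in for a worker: an independent snapshot of ds taken when the
# DataLoader iterator starts (much like forking the main process does).
worker_copy = copy.deepcopy(ds)

# A change made later in the main process only affects the main-process copy.
ds.set_seq_length(10)
print(ds[0].shape)           # torch.Size([10, 2])
print(worker_copy[0].shape)  # torch.Size([33, 2])
```

This matches the num_workers=2 run below: the workers keep producing sequences of length 33, the value set before the loop started.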
Code for num_workers=0:
```python
ds = TestDataset(50 * 64, my_gen_sample)
dl = DataLoader(ds, batch_size=64, num_workers=0)
ds.set_seq_length(33)
lengths = [None] * len(dl)  # one entry per minibatch (50), not per sample
for i, x_b in enumerate(dl):
    # pick a new sequence length for the next minibatch
    seq_length = np.random.randint(low=5, high=50)
    ds.set_seq_length(seq_length)
    lengths[i] = x_b.shape[1]
```
Now compare with num_workers=2:
```python
dl = DataLoader(ds, batch_size=64, num_workers=2)
ds.set_seq_length(33)
lengths2 = [None] * len(dl)  # one entry per minibatch
for i, x_b in enumerate(dl):
    seq_length = np.random.randint(low=5, high=50)
    ds.set_seq_length(seq_length)
    lengths2[i] = x_b.shape[1]

plt.plot(lengths, label="num_workers=0")
plt.plot(lengths2, label="num_workers=2")
plt.legend()
plt.show()
```
I would be grateful for any suggestions on how to get the num_workers=0 behavior with multiple workers. Thank you!
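One possible workaround, sketched under the assumption that the sequence length only needs to be constant within a minibatch: move the choice of seq_length out of mutable Dataset state and into a custom batch sampler. When a DataLoader is given `batch_sampler=`, it iterates the sampler in the main process and sends each batch's list of indices to a worker, which calls `dataset[item]` for every entry; an entry can therefore be an `(index, seq_length)` tuple. `VarLenBatchSampler`, `VarLenDataset` and `collect_lengths` are hypothetical names for illustration, and the 5 to 49 range mirrors the question's `randint(low=5, high=50)` call.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class VarLenBatchSampler(Sampler):
    """Hypothetical helper: yields batches of (index, seq_length) pairs.

    The sequence length for each batch is drawn here, in the main process,
    so every worker receives it together with the indices it must load.
    """

    def __init__(self, n_items, batch_size, low=5, high=50, seed=None):
        self.n_items = n_items
        self.batch_size = batch_size
        self.low, self.high = low, high
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        for start in range(0, self.n_items, self.batch_size):
            seq_length = int(self.rng.integers(self.low, self.high))
            stop = min(start + self.batch_size, self.n_items)
            yield [(idx, seq_length) for idx in range(start, stop)]

    def __len__(self):
        # number of batches (the last one may be smaller)
        return (self.n_items + self.batch_size - 1) // self.batch_size


def my_gen_sample(n):
    return torch.randn(n, 2)


class VarLenDataset(Dataset):
    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample

    def __len__(self):
        return self.n_sims

    def __getitem__(self, item):
        # item is an (index, seq_length) pair produced by the batch sampler
        idx, seq_length = item
        return self.gen_sample(seq_length)


def collect_lengths(num_workers, n_sims=50 * 64, batch_size=64, seed=0):
    ds = VarLenDataset(n_sims, my_gen_sample)
    sampler = VarLenBatchSampler(n_sims, batch_size, seed=seed)
    # batch_sampler replaces batch_size/shuffle/sampler/drop_last
    dl = DataLoader(ds, batch_sampler=sampler, num_workers=num_workers)
    return [x_b.shape[1] for x_b in dl]
```

Because nothing has to reach the workers after they start, `collect_lengths(num_workers=2)` should return the same per-batch lengths as `collect_lengths(num_workers=0)` for a fixed seed. On platforms that start workers with spawn (Windows, macOS), keep these definitions at module level and run the loop under an `if __name__ == "__main__":` guard.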