DataLoader does not access updated Dataset attribute when num_workers > 0

The context here is a scientific machine learning problem in which training data (which happens to be sequences) is simulated from an underlying ground-truth scientific model. A key wrinkle is that it’s important that the model sees simulations of different sizes during training. So a minibatch of data has shape [batch_size, seq_length, seq_depth], where seq_depth is always the same for a given problem, but seq_length needs to change from one minibatch to the next.

Because the simulation of the data is itself nontrivial, I want to use a DataLoader with multiple workers during training. The issue I’m running into is that the workers don’t “see” a change I make to a Dataset attribute that varies the sequence length during training.

Here’s a minimal reproducible example where seq_depth=2 to illustrate the issue.

import torch
import numpy as np 
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

def my_gen_sample(n):
    return torch.randn(n,2)

class TestDataset(Dataset):
    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        self.seq_length = 1  # attribute to be modified after each minibatch

    def __len__(self):
        return self.n_sims

    def set_seq_length(self, n: int):
        self.seq_length = n

    def __getitem__(self, idx):
        x = self.gen_sample(self.seq_length)
        return x

This works as expected when num_workers=0. But when num_workers > 0, the changes to ds.seq_length after each minibatch are no longer reflected.

Code for num_workers=0:

ds = TestDataset(50*64, my_gen_sample)
dl = DataLoader(ds, batch_size=64, num_workers=0)

ds.set_seq_length(33)
lengths = [None] * len(dl)  # one entry per minibatch, not per sample
for i, x_b in enumerate(dl):
    seq_length = np.random.randint(low=5, high=50, size=1)[0]
    ds.set_seq_length(seq_length)
    lengths[i] = x_b.shape[1]

Now compare with num_workers=2:

dl = DataLoader(ds, batch_size=64, num_workers=2)

ds.set_seq_length(33)
lengths2 = [None] * len(dl)  # one entry per minibatch, not per sample
for i, x_b in enumerate(dl):
    seq_length = np.random.randint(low=5, high=50, size=1)[0]
    ds.set_seq_length(seq_length)
    lengths2[i] = x_b.shape[1]

plt.plot(lengths)
plt.plot(lengths2)


I would be grateful for any suggestions on how to get the num_workers=0 behavior with multiple workers. Thank you!

This is expected, since each worker is a separate process working on its own copy of the Dataset, so changes made in the main process after the workers have started are not visible to them.
You could try to share a data structure between all workers and manipulate it from the main process, e.g. via:

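import ctypes
import multiprocessing as mp

class TestDataset(Dataset):
    def __init__(self, n_sims, gen_sample, num_workers=0):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        shared_array_base = mp.Array(ctypes.c_int64, 1)  # one int64 in shared memory
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.seq_length = torch.from_numpy(shared_array)  # tensor view of the shared buffer

    def __getitem__(self, idx):
        # read the current value; the main process would write it via ds.seq_length[0] = n
        return self.gen_sample(int(self.seq_length[0]))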

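but you would run into race conditions, since each worker loads and processes its batch independently: a value written by the main process can change between two samples of the same batch, and batches the workers have already prefetched will still use the old value.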

Based on this description, it sounds as if you could know the sequence length of each batch in advance. If so, you might be able to write a custom BatchSampler instead, which allows you to pass the indices and the sequence length of the entire batch to __getitem__, as seen in this post.

Something like this might work for you:

import torch
import numpy as np 
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset


def my_gen_sample(n):
    return torch.randn(n,2)


class TestDataset(Dataset):
    def __init__(self, n_sims, gen_sample, num_workers=0):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        self.seq_length = 1  # attribute to be modified after each minibatch

    def __len__(self):
        return self.n_sims

    def set_seq_length(self, n: int):
        self.seq_length = n

    def __getitem__(self, idx):
        x = self.gen_sample(self.seq_length)
        return x
    
    
ds = TestDataset(50*64, my_gen_sample)
dl = DataLoader(ds, batch_size=64, num_workers=0)

ds.set_seq_length(33)
lengths = [None] * len(dl)
for i, x_b in enumerate(dl):
    seq_length = np.random.randint(low=5, high=50, size=1)[0]
    ds.set_seq_length(seq_length)
    lengths[i] = x_b.shape[1]


class MyBatchSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, data_indices, seq_lengths, batch_size):
        self.data_indices = data_indices
        self.seq_lengths = seq_lengths
        self.batch_size = batch_size

    def __iter__(self):
        data_indices_iter = iter(self.data_indices)
        seq_lengths_iter = iter(self.seq_lengths)
        while True:
            try:
                batch = [(next(data_indices_iter), next(seq_lengths_iter)) for _ in range(self.batch_size)]
                yield batch
            except StopIteration:
                break
        
    def __len__(self) -> int:
        return len(self.data_indices) // self.batch_size  # type: ignore[arg-type]


class MyTestDataset(Dataset):
    def __init__(self, n_sims, gen_sample, num_workers=0):
        self.n_sims = n_sims
        self.gen_sample = gen_sample

    def __len__(self):
        return self.n_sims

    def __getitem__(self, idx):
        # idx is a whole batch from MyBatchSampler: a list of (data_index, seq_length) pairs
        data_indices, seq_lengths = list(zip(*idx))
        x = []
        for data_index, seq_len in zip(data_indices, seq_lengths):
            x.append(self.gen_sample(seq_len))
        x = torch.stack(x)
        return x


ds = MyTestDataset(50*64, my_gen_sample)    
batch_size = 64
data_indices = torch.arange(len(ds))
seq_lengths = torch.tensor(lengths[:50]).repeat_interleave(batch_size)
assert len(data_indices) == len(seq_lengths)
sampler = MyBatchSampler(data_indices, seq_lengths, batch_size)

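# batch_size defaults to 1, so each item the loader yields wraps one full batch from MyTestDataset
dl = DataLoader(ds, sampler=sampler, num_workers=2)

lengths2 = [None] * len(dl)
for i, x_b in enumerate(dl):
    x_b = x_b.squeeze(0)  # drop the leading dim of size 1 added by the default collate_fn
    lengths2[i] = x_b.shape[1]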

plt.plot(lengths)
plt.plot(lengths2)

print((torch.tensor(lengths) == torch.tensor(lengths2)).all())
# tensor(True)
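A variant of the same idea, sketched under a few assumptions: the sampler is passed via DataLoader's batch_sampler argument instead of sampler, so each __getitem__ call receives a single (index, seq_len) pair and the default collate_fn builds the batch itself (no squeeze(0) needed). Matching the np.random.randint(low=5, high=50) draws in the question, the sampler also draws one random length per batch in its own __iter__ rather than taking a precomputed list. RandomLengthBatchSampler, PairDataset and collect_shapes are hypothetical names for this illustration. Because the DataLoader iterates the batch sampler in the main process and only sends the resulting indices to the workers, every sample in a batch gets the same length regardless of num_workers.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


def my_gen_sample(n):
    # Stand-in simulator from the question: n time steps, seq_depth=2.
    return torch.randn(n, 2)


class RandomLengthBatchSampler(Sampler):
    """Hypothetical batch sampler: yields lists of (index, seq_len) pairs,
    using one randomly drawn seq_len for every pair in a batch."""

    def __init__(self, n_samples, batch_size, low=5, high=50, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.low = low
        self.high = high  # exclusive, like np.random.randint
        self.generator = torch.Generator().manual_seed(seed)

    def __len__(self):
        # The incomplete trailing batch is dropped.
        return self.n_samples // self.batch_size

    def __iter__(self):
        # One length per batch, drawn in the main process at the start of each epoch.
        lengths = torch.randint(self.low, self.high, (len(self),), generator=self.generator)
        for b in range(len(self)):
            seq_len = int(lengths[b])
            start = b * self.batch_size
            yield [(i, seq_len) for i in range(start, start + self.batch_size)]


class PairDataset(Dataset):
    """Hypothetical dataset whose keys are (index, seq_len) pairs."""

    def __init__(self, n_sims, gen_sample):
        self.n_sims = n_sims
        self.gen_sample = gen_sample

    def __len__(self):
        return self.n_sims

    def __getitem__(self, key):
        idx, seq_len = key  # idx is unused here, as in MyTestDataset
        return self.gen_sample(seq_len)


def collect_shapes(num_workers=0, n_batches=50, batch_size=64, seed=0):
    ds = PairDataset(n_batches * batch_size, my_gen_sample)
    sampler = RandomLengthBatchSampler(len(ds), batch_size, seed=seed)
    # With batch_sampler, the default collate_fn stacks the batch_size tensors of
    # shape [seq_len, 2] into [batch_size, seq_len, 2]; no squeeze(0) is needed.
    dl = DataLoader(ds, batch_sampler=sampler, num_workers=num_workers)
    return [tuple(x_b.shape) for x_b in dl]
```

Calling collect_shapes(num_workers=2) from a script should give the same per-batch consistency, since the lengths never live in the workers; on platforms that start workers with spawn, that call belongs under an if __name__ == "__main__": guard.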

Thank you so much! This seems like a perfect solution. I had wondered about a solution where the Dataset itself stored the lengths held in your MyBatchSampler.seq_lengths, so that every worker received a copy of which idx went with which sequence length, ensuring a consistent seq_length within a single batch. But that required the Dataset to know what batch size to expect during training. Your solution encapsulates those sequence lengths more straightforwardly.
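The Dataset-side alternative described above can be sketched as follows (LengthTableDataset is a hypothetical name, not code from this thread). It expands one length per batch into one length per sample index, which is exactly why it needs the batch size up front, and it also assumes the DataLoader visits indices in order. Since the table is fixed at construction and simply copied to each worker, it does not depend on num_workers.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def my_gen_sample(n):
    # Stand-in simulator from the question: n time steps, seq_depth=2.
    return torch.randn(n, 2)


class LengthTableDataset(Dataset):
    """Hypothetical sketch: the Dataset owns a per-index seq_length table,
    so it must know the training batch size when it is built."""

    def __init__(self, n_sims, gen_sample, batch_lengths, batch_size):
        assert len(batch_lengths) * batch_size == n_sims
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        # Expand one length per batch into one length per sample index.
        self.lengths = torch.as_tensor(batch_lengths).repeat_interleave(batch_size)

    def __len__(self):
        return self.n_sims

    def __getitem__(self, idx):
        # Consistent within a batch only if the DataLoader uses the same
        # batch_size and visits indices in order (no shuffling).
        return self.gen_sample(int(self.lengths[idx]))
```

With DataLoader(ds, batch_size=3, shuffle=False) each batch has a single length; with shuffle=True or a different batch_size, samples of different lengths can land in the same batch and the default collate_fn fails when it tries to stack them. The sampler-based design avoids that coupling.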


Great to hear the approach might work for you.
Just wanted to let you know that the num_workers argument in TestDataset (and MyTestDataset) is dead code from a previous test, so you can ignore it. ;)