The issue is expected, since each worker is a separate process working on a copy of the Dataset.
You could try to share a data structure between all workers and manipulate it from the main thread e.g. via:
class TestDataset(Dataset):
    """Dataset whose ``seq_length`` is backed by shared memory, so updates
    made in the main process are visible to DataLoader worker processes."""

    def __init__(self, n_sims, gen_sample, num_workers=0):
        self.n_sims = n_sims          # total number of samples
        self.gen_sample = gen_sample  # callable producing one sample of a given length
        # Shared one-element int64 buffer; NOTE(review): `mp` is presumably
        # multiprocessing and `ctypes` the stdlib module — imports are not
        # shown in this snippet.
        shared_array_base = mp.Array(ctypes.c_int64, 1)
        # Zero-copy NumPy view over the shared ctypes buffer.
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        # Tensor view of the same memory: mutating this tensor in the main
        # process changes the value every worker reads.
        self.seq_length = torch.from_numpy(shared_array)
but you would run into race conditions since each worker loads and processes its batch independently.
Based on this description, it sounds as if you know the sequence length in advance for each batch. If so, you might be able to write a custom BatchSampler instead, which allows you to pass the indices and the sequence length of the entire batch to __getitem__, as seen in this post.
Something like this might work for you:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
def my_gen_sample(n):
    """Draw one sample of shape ``(n, 2)`` from a standard normal."""
    return torch.randn((n, 2))
class TestDataset(Dataset):
    """Toy dataset that generates each sample on the fly with a mutable
    sequence length.

    ``seq_length`` is a plain Python attribute, so mutating it between
    batches only works in the main process (``num_workers=0``): worker
    processes operate on copies of the dataset.
    """

    def __init__(self, n_sims, gen_sample, num_workers=0):
        self.n_sims = n_sims
        self.gen_sample = gen_sample
        # Intended to be updated after every minibatch via set_seq_length.
        self.seq_length = 1

    def __len__(self):
        return self.n_sims

    def set_seq_length(self, n: int):
        # Thin setter; kept for readability at the call site.
        self.seq_length = n

    def __getitem__(self, idx):
        # The index is ignored — every item is freshly generated with the
        # currently configured sequence length.
        return self.gen_sample(self.seq_length)
# Single-process experiment: mutate the dataset's seq_length between batches
# and record the sequence length each batch actually came out with.
ds = TestDataset(50 * 64, my_gen_sample)
dl = DataLoader(ds, batch_size=64, num_workers=0)
ds.set_seq_length(33)  # length used for the very first batch

lengths = []
for x_b in dl:
    # Choose the length for the *next* batch; x_b is already fetched, so
    # mutating the dataset here cannot affect it.
    ds.set_seq_length(np.random.randint(low=5, high=50, size=1)[0])
    lengths.append(x_b.shape[1])
class MyBatchSampler(torch.utils.data.sampler.Sampler):
    """Sampler yielding batches of ``(data_index, seq_length)`` pairs.

    Pairing every index with its precomputed sequence length lets the
    dataset's ``__getitem__`` build an entire batch at once.  A trailing
    partial batch is dropped, consistent with ``__len__``.
    """

    def __init__(self, data_indices, seq_lengths, batch_size):
        self.data_indices = data_indices
        self.seq_lengths = seq_lengths
        self.batch_size = batch_size

    def __iter__(self):
        # Pair indices with lengths; zip stops at the shorter sequence,
        # matching the exhaust-first behaviour of the two-iterator version.
        pairs = list(zip(self.data_indices, self.seq_lengths))
        full_batches = len(pairs) // self.batch_size
        for b in range(full_batches):
            start = b * self.batch_size
            yield pairs[start:start + self.batch_size]

    def __len__(self) -> int:
        # Only full batches are produced.
        return len(self.data_indices) // self.batch_size  # type: ignore[arg-type]
class MyTestDataset(Dataset):
    """Dataset whose ``__getitem__`` consumes a whole batch description.

    ``idx`` is expected to be a list of ``(data_index, seq_length)`` tuples
    (as produced by ``MyBatchSampler``); the method returns one stacked
    tensor covering the entire batch.
    """

    def __init__(self, n_sims, gen_sample, num_workers=0):
        self.n_sims = n_sims
        self.gen_sample = gen_sample

    def __len__(self):
        return self.n_sims

    def __getitem__(self, idx):
        # The data indices themselves are unused; only the per-item
        # sequence lengths matter for generation.
        _, seq_lengths = zip(*idx)
        return torch.stack([self.gen_sample(n) for n in seq_lengths])
# Re-run the experiment through the batch sampler, now multi-worker safe
# because the per-sample sequence lengths travel with the indices.
ds = MyTestDataset(50*64, my_gen_sample)
batch_size = 64
data_indices = torch.arange(len(ds))
# One length per sample: repeat each per-batch length recorded in the
# earlier num_workers=0 run batch_size times.
seq_lengths = torch.tensor(lengths[:50]).repeat_interleave(batch_size)
assert len(data_indices) == len(seq_lengths)
sampler = MyBatchSampler(data_indices, seq_lengths, batch_size)
# NOTE: the batch sampler is passed as `sampler` (not `batch_sampler`), so
# each "index" the loader fetches is already a full batch and the default
# batch_size=1 wraps it in an extra leading dimension.
dl = DataLoader(ds, sampler=sampler, num_workers=2)
lengths2 = [None] * len(dl)
for i, x_b in enumerate(dl):
    # Drop the dummy dimension added by the loader's batch_size=1.
    x_b = x_b.squeeze(0)
    lengths2[i] = x_b.shape[1]
# Visual + exact check that both runs produced the same length schedule.
plt.plot(lengths)
plt.plot(lengths2)
print((torch.tensor(lengths) == torch.tensor(lengths2)).all())
# tensor(True)