Does a DataLoader change the random state even when the shuffle argument is False?

Thanks for the code.
I can reproduce this behavior using two simple DataLoaders:

# Toy regression data: 10 scalar inputs with targets offset by 0.1, so each
# printed batch is easy to identify by eye.
X_train = torch.arange(10).float().view(-1, 1)
y_train = torch.arange(10).float().view(-1, 1) + 0.1
train_dataset = dataset_class(X_train, y_train)
# The shuffled loader — its batch order is what we try to reproduce by seeding.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Disjoint value range (10..19) so validation batches are distinguishable
# from training batches in the output.
X_val = torch.arange(10, 20).float().view(-1, 1)
y_val = torch.arange(10, 20).float().view(-1, 1) + 0.1
val_dataset = dataset_class(X_val, y_val)
# shuffle=False on purpose: the point of the demo is that iterating even this
# loader appears to consume global PRNG state — TODO confirm against the
# _BaseDataLoaderIter source.
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)


# Experiment 1: seed the global PRNG, then run two epochs over the shuffled
# training loader only. The per-epoch batch order is the baseline.
seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for step, (inputs, targets) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(step, inputs))
    print('=' * 10)

# Experiment 2: identical to experiment 1 — re-seeding reproduces the same
# batch order, confirming the run is deterministic given the seed.
seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for step, (inputs, targets) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(step, inputs))
    print('=' * 10)

# Experiment 3: same seed, but now the (unshuffled) validation loader is
# iterated after each training epoch. Compare epoch 2 against the baseline.
seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Adding validation loop')
for epoch in range(2):
    for step, (inputs, targets) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(step, inputs))
    for val_step, (val_inputs, val_targets) in enumerate(val_loader):
        print('ValIter{}, X_val: {}'.format(val_step, val_inputs))
    print('=' * 10)

If you execute the code (with your dataset_class definition), you’ll see that the train_loader batches in the second epoch differ from the baseline once the val_loader is iterated — even though the val_loader uses shuffle=False.

My best guess is that the _BaseDataLoaderIter calls into the global PRNG in this line of code (to draw the _base_seed), which is needed to seed each worker here.

A workaround would be to create a torch.Generator manually and pass it to your train_loader, so that PyTorch uses it for the _base_seed creation:

# Dedicated RNG for the training loader. Passing it via the `generator`
# argument makes the DataLoader draw its _base_seed from `gen` instead of the
# global PRNG, decoupling the shuffle order from anything else that consumes
# global random state.
gen = torch.Generator()

# Same toy data as before.
X_train = torch.arange(10).float().view(-1, 1)
y_train = torch.arange(10).float().view(-1, 1) + 0.1
train_dataset = dataset_class(X_train, y_train)
# Only difference from the first setup: generator=gen.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, generator=gen)

X_val = torch.arange(10, 20).float().view(-1, 1)
y_val = torch.arange(10, 20).float().view(-1, 1) + 0.1
val_dataset = dataset_class(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)


# Experiment 1 (workaround): seed the dedicated generator rather than the
# global PRNG, then run two epochs over the training loader only.
seed = 2809
print('Seeding with {}'.format(seed))
# No torch.manual_seed(seed) here — only the loader's own generator is seeded.
gen.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for step, (inputs, targets) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(step, inputs))
    print('=' * 10)

# Experiment 2: repeat to confirm the order is reproducible from the seed.
seed = 2809
print('Seeding with {}'.format(seed))
# Global PRNG deliberately left untouched.
gen.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for step, (inputs, targets) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(step, inputs))
    print('=' * 10)

# Experiment 3: add the validation loop. Because the training loader now owns
# its generator, iterating val_loader should no longer perturb the shuffle.
seed = 2809
print('Seeding with {}'.format(seed))
# Global PRNG deliberately left untouched.
gen.manual_seed(seed)
print('Adding validation loop')
for epoch in range(2):
    for step, (inputs, targets) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(step, inputs))
    for val_step, (val_inputs, val_targets) in enumerate(val_loader):
        print('ValIter{}, X_val: {}'.format(val_step, val_inputs))
    print('=' * 10)
2 Likes