Thanks for the code.
I can reproduce this behavior using two simple `DataLoader`s:
# Minimal reproduction: two DataLoaders drawing from the global PRNG.
# NOTE(review): `dataset_class` is the user's Dataset implementation,
# defined elsewhere in the thread — it is assumed to wrap (X, y) tensors.
X_train = torch.arange(10).float().view(-1, 1)
y_train = torch.arange(10).float().view(-1, 1) + 0.1
train_dataset = dataset_class(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

X_val = torch.arange(10, 20).float().view(-1, 1)
y_val = torch.arange(10, 20).float().view(-1, 1) + 0.1
val_dataset = dataset_class(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Run 1: seed the global PRNG, iterate the training loader only.
seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

# Run 2: re-seed identically and repeat — batches match run 1.
seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

# Run 3: same seed, but interleave a validation loop after each training
# epoch. Creating the val_loader iterator consumes the global PRNG, so the
# second-epoch training batches no longer match runs 1 and 2.
seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Adding validation loop')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    for j, (x_v, y_v) in enumerate(val_loader):
        print('ValIter{}, X_val: {}'.format(j, x_v))
    print('='*10)
If you execute the code (with your `dataset_class` definition), you'll see that the `train_loader` batches are not the same in the second epoch once the `val_loader` (which does not shuffle) was iterated in between.
My best guess is that `_BaseDataLoaderIter` calls into the PRNG in this line of code, which is needed to seed each worker here.
A workaround would be to create a `torch.Generator` manually and pass it to your `train_loader`, so that PyTorch uses it (instead of the global PRNG) for the `_base_seed` creation:
# Workaround: give the training loader its own torch.Generator so its
# shuffling no longer depends on the global PRNG state (which the
# validation loader perturbs).
gen = torch.Generator()

# NOTE(review): `dataset_class` is the user's Dataset implementation,
# defined elsewhere in the thread — it is assumed to wrap (X, y) tensors.
X_train = torch.arange(10).float().view(-1, 1)
y_train = torch.arange(10).float().view(-1, 1) + 0.1
train_dataset = dataset_class(X_train, y_train)
# The dedicated generator is passed only to the shuffling train_loader.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, generator=gen)

X_val = torch.arange(10, 20).float().view(-1, 1)
y_val = torch.arange(10, 20).float().view(-1, 1) + 0.1
val_dataset = dataset_class(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Run 1: seed the dedicated generator (not the global PRNG) and iterate
# the training loader only.
seed = 2809
print('Seeding with {}'.format(seed))
#torch.manual_seed(seed)
gen.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

# Run 2: re-seed the generator identically — batches match run 1.
seed = 2809
print('Seeding with {}'.format(seed))
#torch.manual_seed(seed)
gen.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

# Run 3: interleave the validation loop — with the dedicated generator the
# training batches now stay identical to runs 1 and 2.
seed = 2809
print('Seeding with {}'.format(seed))
#torch.manual_seed(seed)
gen.manual_seed(seed)
print('Adding validation loop')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    for j, (x_v, y_v) in enumerate(val_loader):
        print('ValIter{}, X_val: {}'.format(j, x_v))
    print('='*10)