Does a dataloader change random state even when shuffle argument is False?

Hello.

I am facing unexpected behaviour in my data loading scheme. I have two dataloaders, train_dl and test_dl. train_dl provides batches of data with shuffle=True and test_dl provides batches with shuffle=False.

I evaluate my test metrics every N epochs, i.e., every N epochs I loop over the test_dl dataset. I have noticed that if the value of N changes, the shuffled batches provided by train_dl after evaluating the test metrics differ. In other words, suppose my training points are labelled a, b, c, d, e, f and my batch size is 2. If I seed everything correctly, I always observe the following pattern:

Epoch 1:
Batch 1: [a,b]
Batch 2: [e,d]
Batch 3: [c,f]

Epoch 2:
Batch 1: [a,f]
Batch 2: [b,d]
Batch 3: [e,c]

However, if I loop over test_dl between epoch 1 and epoch 2, the batches provided at epoch 2 differ. In other words, even though shuffle=False for test_dl, the random state seems to change.

Is this true?

Thanks

Based on the description, it seems that the test DataLoader is calling into the PRNG at some point. Although shuffle is set to False, the Dataset might still use it somewhere. Could you post the Dataset implementation (for the test set in particular) as well as the test loop?
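One way to narrow this down, as a minimal sketch: seed the global PRNG, consume either the Dataset or the DataLoader, then draw one number and compare it with a run where nothing else touched the PRNG. `TensorDataset` is used here as a stand-in for a plain map-style Dataset that returns `(X[idx], Y[idx])`, and the helper `next_draw_after` is a name chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the test set: TensorDataset returns (X[idx], Y[idx]) like a plain map-style Dataset.
X = torch.arange(10).float().view(-1, 1)
Y = X + 0.1
dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=2, shuffle=False)


def next_draw_after(consume):
    """Seed the global PRNG, run `consume`, then return the next global random number."""
    torch.manual_seed(0)
    consume()
    return torch.rand(1).item()


baseline = next_draw_after(lambda: None)
# Indexing the Dataset directly: no random calls expected.
after_dataset = next_draw_after(lambda: [dataset[i] for i in range(len(dataset))])
# Iterating a DataLoader with shuffle=False.
after_loader = next_draw_after(lambda: [batch for batch in loader])

print("Dataset indexing touches the global PRNG:", after_dataset != baseline)
print("DataLoader(shuffle=False) touches the global PRNG:", after_loader != baseline)
```

If the first check reports no change and the second does, the Dataset is not the culprit and the draw happens inside the DataLoader machinery.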

This is the dataset class:


import torch


class dataset_class(torch.utils.data.Dataset):
    ''' General dataset class
        Args:
                X: input -> torch.Tensor
                Y: targets -> torch.Tensor
    '''

    def __init__(self, X: torch.Tensor, Y: torch.Tensor):
        super(dataset_class, self).__init__()
        assert X is not None, "Invalid None type"
        assert Y is not None, "Invalid None type"
        self.X = X
        self.Y = Y

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        # Plain indexing, no random calls
        return self.X[idx], self.Y[idx]

where self.X and self.Y are the same for all runs, i.e., their order is not shuffled unless the DataLoader does it. This is the main loop:

# ep is the index of the enclosing epoch loop (not shown)
for b, (x, y) in enumerate(train_loader):
    x, y = x.to(cg.device), y.to(cg.device)

    ## loss
    loss = model.loss(x, y)
    optimizer.zero_grad()  # just in case
    loss.backward()
    optimizer.step()

    # Performance metrics
    if ((ep + 1) % self.validate_each) == 0:
        model.performance_metrics(x, y, dataset='train')

# Evaluate on the test set every validate_each epochs
with torch.no_grad():
    if ((ep + 1) % self.validate_each) == 0:
        for b, (x, y) in enumerate(test_loader):
            x, y = x.to(cg.device), y.to(cg.device)
            model.performance_metrics(x, y, dataset='test')
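A different workaround from the one suggested later in this thread, sketched here under assumptions: wrap the evaluation pass in `torch.random.fork_rng`, which saves the RNG state on entry and restores it on exit, so any draws made while iterating `test_loader` are discarded. The toy tensors and the `run` helper (with its `validate_each` argument) are hypothetical, and the model, optimizer and `cg.device` are left out.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real train/test sets (hypothetical data).
train_loader = DataLoader(TensorDataset(torch.arange(10).float().view(-1, 1)), batch_size=2, shuffle=True)
test_loader = DataLoader(TensorDataset(torch.arange(10, 20).float().view(-1, 1)), batch_size=2, shuffle=False)


def run(validate_each, epochs=4, seed=2809):
    """Return the training batches seen in every epoch (model/optimizer omitted)."""
    torch.manual_seed(seed)
    seen = []
    for ep in range(epochs):
        for (x,) in train_loader:
            seen.append(x.flatten().tolist())
        if (ep + 1) % validate_each == 0:
            # devices=[] forks only the CPU generator: its state is saved here and
            # restored when the block exits, so test_loader cannot shift the next shuffle.
            with torch.random.fork_rng(devices=[]), torch.no_grad():
                for (x,) in test_loader:
                    pass  # model.performance_metrics(x, y, dataset='test') would go here
    return seen


print(run(validate_each=1) == run(validate_each=3))
```

The trade-off is that anything else drawing random numbers inside the block (e.g. stochastic layers left in training mode) is also rolled back, which is usually what you want during evaluation.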

Thanks for the code.
I can reproduce this behavior using two simple DataLoaders:

import torch
from torch.utils.data import DataLoader

X_train = torch.arange(10).float().view(-1, 1)
y_train = torch.arange(10).float().view(-1, 1) + 0.1
train_dataset = dataset_class(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

X_val = torch.arange(10, 20).float().view(-1, 1)
y_val = torch.arange(10, 20).float().view(-1, 1) + 0.1
val_dataset = dataset_class(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)


seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

seed = 2809
print('Seeding with {}'.format(seed))
torch.manual_seed(seed)
print('Adding validation loop')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    for j, (x_v, y_v) in enumerate(val_loader):
        print('ValIter{}, X_val: {}'.format(j, x_v))
    print('='*10)

If you execute the code (with your dataset_class definition), you'll see that the train_loader batches in the second epoch are not the same once the val_loader has been iterated, even though the val_loader does not shuffle.
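The same comparison can be condensed into a couple of checks instead of reading printed batches. This is a sketch: `TensorDataset` stands in for `dataset_class` (both return `(X[idx], Y[idx])`), and `epoch_batches` is a helper name chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# TensorDataset stands in for dataset_class; both return (X[idx], Y[idx]).
X_train = torch.arange(10).float().view(-1, 1)
train_loader = DataLoader(TensorDataset(X_train, X_train + 0.1), batch_size=2, shuffle=True)
X_val = torch.arange(10, 20).float().view(-1, 1)
val_loader = DataLoader(TensorDataset(X_val, X_val + 0.1), batch_size=2, shuffle=False)


def epoch_batches(with_validation, seed=2809, epochs=2):
    """Seed the global PRNG and return the training batches of each epoch."""
    torch.manual_seed(seed)
    epochs_seen = []
    for _ in range(epochs):
        epochs_seen.append([x.flatten().tolist() for x, _y in train_loader])
        if with_validation:
            # No shuffling here, but iterating still advances the global PRNG
            for _batch in val_loader:
                pass
    return epochs_seen


train_only = epoch_batches(with_validation=False)
with_val = epoch_batches(with_validation=True)
print("epoch 1 identical:", train_only[0] == with_val[0])
print("epoch 2 identical:", train_only[1] == with_val[1])
```

Epoch 1 matches because nothing has touched the PRNG yet; epoch 2 diverges because the validation pass ran in between.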

My best guess is that _BaseDataLoaderIter calls into the PRNG when it is created, in order to draw a _base_seed, which would be needed to seed each worker.
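If that guess is right, the seed draw goes through the loader's `generator` argument when one is set and through the global PRNG otherwise. A quick way to check it, as a sketch: give only the validation loader a dedicated `torch.Generator` (`val_gen` and the helper `global_draw_after_iterating` are names chosen here) and see whether iterating it still moves the global stream.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

X_val = torch.arange(10, 20).float().view(-1, 1)
val_dataset = TensorDataset(X_val, X_val + 0.1)

plain_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
val_gen = torch.Generator()  # dedicated generator, only used by this loader
own_gen_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, generator=val_gen)


def global_draw_after_iterating(loader, seed=2809):
    """Seed the global PRNG, exhaust `loader`, then sample once from the global PRNG."""
    torch.manual_seed(seed)
    for _ in loader:
        pass
    return torch.rand(1).item()


torch.manual_seed(2809)
untouched = torch.rand(1).item()  # what the global PRNG yields if nothing else drew from it

plain_shifts = global_draw_after_iterating(plain_loader) != untouched
own_gen_shifts = global_draw_after_iterating(own_gen_loader) != untouched
print("plain loader shifts global PRNG:", plain_shifts)
print("loader with own generator shifts it:", own_gen_shifts)
```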

A workaround would be to create a torch.Generator manually and pass it to your train_loader, so that PyTorch uses it (instead of the global PRNG) for the train_loader's shuffling and its _base_seed creation:

import torch
from torch.utils.data import DataLoader

gen = torch.Generator()

X_train = torch.arange(10).float().view(-1, 1)
y_train = torch.arange(10).float().view(-1, 1) + 0.1
train_dataset = dataset_class(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, generator=gen)

X_val = torch.arange(10, 20).float().view(-1, 1)
y_val = torch.arange(10, 20).float().view(-1, 1) + 0.1
val_dataset = dataset_class(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)


seed = 2809
print('Seeding with {}'.format(seed))
#torch.manual_seed(seed)
gen.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

seed = 2809
print('Seeding with {}'.format(seed))
#torch.manual_seed(seed)
gen.manual_seed(seed)
print('Training loop only')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    print('='*10)

seed = 2809
print('Seeding with {}'.format(seed))
#torch.manual_seed(seed)
gen.manual_seed(seed)
print('Adding validation loop')
for epoch in range(2):
    for i, (x, y) in enumerate(train_loader):
        print('Iter{}, X_train: {}'.format(i, x))
    for j, (x_v, y_v) in enumerate(val_loader):
        print('ValIter{}, X_val: {}'.format(j, x_v))
    print('='*10)
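Since the original symptom was that results change with `validate_each`, the generator workaround can be checked against that directly. This is a sketch with toy data; `training_order` is a hypothetical helper and `TensorDataset` replaces `dataset_class`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

X_train = torch.arange(10).float().view(-1, 1)
X_val = torch.arange(10, 20).float().view(-1, 1)

gen = torch.Generator()  # drives train_loader's shuffling and its per-iterator seed
train_loader = DataLoader(TensorDataset(X_train, X_train + 0.1), batch_size=2, shuffle=True, generator=gen)
val_loader = DataLoader(TensorDataset(X_val, X_val + 0.1), batch_size=2, shuffle=False)


def training_order(validate_each, epochs=6, seed=2809):
    """Return every training batch, validating every `validate_each` epochs."""
    gen.manual_seed(seed)
    order = []
    for ep in range(epochs):
        order.extend(x.flatten().tolist() for x, _y in train_loader)
        if (ep + 1) % validate_each == 0:
            # Still draws from the global PRNG, which train_loader no longer uses
            for _batch in val_loader:
                pass
    return order


orders = {n: training_order(validate_each=n) for n in (1, 2, 3)}
print("same order for every validate_each:", orders[1] == orders[2] == orders[3])
```

Note that the global PRNG is still advanced by val_loader, so other consumers of it (e.g. dropout) can still differ between settings; only the training batch order is decoupled.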

Okay, thank you. Normally I don't see a difference in performance due to how batches are shuffled, but for my particular problem I am getting very different results just by setting validate_each to different values (different values imply a different number of loops over the test dataset, hence a different batch order).

Thanks for your solution