[DataLoader Problem] Problem arises when shuffle = True

From the second epoch onwards, the training batches in my two pipelines are different, even though both are seeded identically.

I have raised the issue here: https://github.com/pytorch/pytorch/issues/20717
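
To illustrate, here is a minimal sketch (hypothetical data and loader settings; my real pipelines are larger). Running the same seeded pipeline with and without a validation loader, the first epoch matches but the second already diverges:

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(100).float())

def first_train_batches(also_run_val):
    torch.manual_seed(42)
    train_iterator = DataLoader(data, batch_size=10, shuffle=True)
    val_iterator = DataLoader(data, batch_size=10, shuffle=False)
    batches = []
    for ep_num in range(2):
        # first training batch of this epoch
        batches.append(next(iter(train_iterator))[0])
        if also_run_val:
            # creating the val iterator draws from the global RNG state
            for _ in val_iterator:
                pass
    return batches

a = first_train_batches(False)  # pipeline without validation
b = first_train_batches(True)   # pipeline with validation
print(torch.equal(a[0], b[0]))  # True:  epoch 1 matches
print(torch.equal(a[1], b[1]))  # False: epoch 2 differs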

I had a look at the GitHub issue and the relevant files.
It is a tricky issue, and it is caused by the line that updates the RNG state (unnecessarily?): https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py#L437
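
You can see the draw directly: creating an iterator advances the global RNG state even when no shuffling is requested (a minimal sketch; the False result is the behaviour on affected versions):

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.zeros(4)), shuffle=False)

torch.manual_seed(0)
before = torch.get_rng_state()
_ = iter(loader)  # no random sampling is needed, yet a base seed is drawn
after = torch.get_rng_state()
print(torch.equal(before, after))  # False on affected versions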

I see there are 2 workarounds.

  1. The DataLoader code can be fixed by placing the base seed calculation inside the if block (https://github.com/pytorch/pytorch/pull/20749); a sketch of the idea appears after the code below.

  2. Wrap your training loop with torch.get_rng_state() / torch.set_rng_state() calls, as below:

import torch  # train_iterator, val_iterator, test_iterator are your existing DataLoaders

prev_rng_state = torch.get_rng_state()  # snapshot the RNG state before training

for ep_num in range(3):
    print("==================================", ep_num + 1, "========================")

    # restore the state saved after the previous training epoch, so the
    # val/test loaders below cannot influence the training shuffle order
    torch.set_rng_state(prev_rng_state)
    for batch, (X_train, y_train, weights) in enumerate(train_iterator):
        if batch == 0:
            print("15 examples of train")
            print(X_train[0:15, 0])

    prev_rng_state = torch.get_rng_state()  # save the state right after training

    for batch, (X_val, y_val) in enumerate(val_iterator):
        if batch == 0:
            print("15 examples of validation")
            print(X_val[0:15, 0])

    for batch, (X_test, y_test) in enumerate(test_iterator):
        if batch == 0:
            print("15 examples of test")
            print(X_test[0:15, 0])
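
For the record, the gist of workaround 1 (my paraphrase of the linked PR, not its exact diff) is to guard the base seed draw inside the iterator's __init__, along the lines of:

import torch

num_workers = 0  # stand-in for self.num_workers inside the iterator
if num_workers > 0:
    # only multi-process loaders need a base seed to seed their workers,
    # so single-process loaders no longer advance the global RNG
    base_seed = torch.empty((), dtype=torch.int64).random_().item()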

Thanks a lot Arul. Now I am getting the same outputs. A small change to your code if you want both pipelines to get the same output: save the RNG state after the validation loop instead of right after training, so only the test loader's draws are discarded:

prev_rng_state = torch.get_rng_state()  # snapshot the RNG state before training

for ep_num in range(10):
    print("==================================", ep_num + 1, "========================")

    torch.set_rng_state(prev_rng_state)  # restore the state saved last epoch
    for batch, (X_train, y_train, weights) in enumerate(train_iterator):
        if batch == 0:
            print("15 examples of train")
            print(X_train[0:15, 0])

    for batch, (X_val, y_val) in enumerate(val_iterator):
        if batch == 0:
            print("15 examples of validation")
            print(X_val[0:15, 0])

    # save after validation instead of right after training, so the
    # validation loader's draws are kept and only the test loader's
    # draws are discarded
    prev_rng_state = torch.get_rng_state()

    for batch, (X_test, y_test) in enumerate(test_iterator):
        if batch == 0:
            print("15 examples of test")
            print(X_test[0:15, 0])