From the second epoch onwards they are different.
I had a look at the GitHub issue and the relevant files.
It is a tricky issue, caused by the line that (perhaps unnecessarily) updates the RNG
state: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py#L437
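To see the effect in isolation, here is a minimal sketch (the toy dataset is made up for illustration): on affected versions, merely creating the iterator draws a base seed from the global RNG, even with shuffle=False and num_workers=0.

import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.arange(10).float())  # hypothetical toy dataset
loader = DataLoader(toy, batch_size=2, shuffle=False, num_workers=0)

state_before = torch.get_rng_state()
_ = iter(loader)  # on affected versions this alone draws a base seed
state_after = torch.get_rng_state()

# False on affected versions; True once the seed is only drawn for workers
print(torch.equal(state_before, state_after))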
I see there are two workarounds:

- The DataLoader code can be fixed by placing the base seed calculation inside the if block (https://github.com/pytorch/pytorch/pull/20749); a rough sketch of that change follows the code below.
- Wrap your training code with get_rng_state() / set_rng_state() calls, as below:
import torch

# train_iterator, val_iterator and test_iterator are assumed to be the
# DataLoaders from the original training script.
prev_rng_state = torch.get_rng_state()  # get previous rng state
for ep_num in range(3):
    print("==================================", ep_num + 1, "========================")
    torch.set_rng_state(prev_rng_state)  # restore rng state
    for batch, (X_train, y_train, weights) in enumerate(train_iterator):
        if batch == 0:
            print("15 examples of train")
            print(X_train[0:15, 0])
    prev_rng_state = torch.get_rng_state()  # save rng state
    for batch, (X_val, y_val) in enumerate(val_iterator):
        if batch == 0:
            print("15 examples of validation")
            print(X_val[0:15, 0])
    for batch, (X_test, y_test) in enumerate(test_iterator):
        if batch == 0:
            print("15 examples of test")
            print(X_test[0:15, 0])
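For completeness, the first workaround amounts to roughly the following change inside the DataLoader iterator's __init__ (paraphrased from the PR description, not the exact diff):

# Before: the base seed was drawn from the global RNG unconditionally,
# perturbing torch's RNG state even for num_workers=0 loaders:
#
#     base_seed = torch.LongTensor(1).random_()[0]
#     if self.num_workers > 0:
#         ...  # seed each worker with base_seed + worker_id
#
# After: the draw only happens when workers actually need it:
#
#     if self.num_workers > 0:
#         base_seed = torch.LongTensor(1).random_()[0]
#         ...  # seed each worker with base_seed + worker_id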
Thanks a lot, Arul. Now I am getting the same outputs. A small change to your code, in case you want both pipelines to produce the same output:
prev_rng_state = torch.get_rng_state()  # get previous rng state
for ep_num in range(10):
    print("==================================", ep_num + 1, "========================")
    torch.set_rng_state(prev_rng_state)  # restore rng state
    for batch, (X_train, y_train, weights) in enumerate(train_iterator):
        if batch == 0:
            print("15 examples of train")
            print(X_train[0:15, 0])
    for batch, (X_val, y_val) in enumerate(val_iterator):
        if batch == 0:
            print("15 examples of validation")
            print(X_val[0:15, 0])
    prev_rng_state = torch.get_rng_state()  # save rng state (moved after the validation loop)
    for batch, (X_test, y_test) in enumerate(test_iterator):
        if batch == 0:
            print("15 examples of test")
            print(X_test[0:15, 0])
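If you need this in more places, the bookkeeping can be packaged into a small context manager. This is just one way to wrap the same get_rng_state() / set_rng_state() trick, not something from the thread:

import contextlib
import torch

@contextlib.contextmanager
def rng_state_preserved():
    """Snapshot the global RNG state and restore it on exit, so whatever
    runs inside (e.g. a test loop) cannot perturb the training stream."""
    state = torch.get_rng_state()
    try:
        yield
    finally:
        torch.set_rng_state(state)

# Usage sketch with the same hypothetical iterators as above:
# with rng_state_preserved():
#     for batch, (X_test, y_test) in enumerate(test_iterator):
#         ...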