Hello everyone, I am trying to use PyTorch to save model checkpoints, optimizer states, and random states so that training can be resumed later. However, I found that if the model contains dropout, the resumed run does not reproduce the same output as the original run, even after loading the checkpoint, the random seeds, and the random states.
Below is a minimal test case. First run it with load_checkpoint=0 so that it trains and saves the checkpoint (at epoch save_at_epoch = 6), then set load_checkpoint=1 and run it again to resume from that checkpoint.
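For example, assuming the script is saved as resume_test.py (the file name is just a placeholder for this post):

python resume_test.py --load_checkpoint 0 --use_dropout 1    # first run: train and save at epoch 6
python resume_test.py --load_checkpoint 1 --use_dropout 1    # second run: resume and print the loss again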
As you can see, I already set the same random seeds (covering torch, torch.cuda, numpy, and random) and the optimizer state before the experiment starts.
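(In other words, the per-file saving in the script below is equivalent to bundling everything into a single checkpoint, roughly like the sketch here; the helper names are made up for illustration and are not in my script.)

import random
import numpy as np
import torch

def save_resume_state(path, model, optimizer):
    # bundle model weights, optimizer state, and all RNG states into one file
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "torch_rng": torch.get_rng_state(),
        "cuda_rng": torch.cuda.get_rng_state_all(),
        "numpy_rng": np.random.get_state(),
        "python_rng": random.getstate(),
    }, path)

def load_resume_state(path, model, optimizer, device):
    # restore everything that was saved above
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    torch.set_rng_state(ckpt["torch_rng"])
    torch.cuda.set_rng_state_all(ckpt["cuda_rng"])
    np.random.set_state(ckpt["numpy_rng"])
    random.setstate(ckpt["python_rng"])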
In theory, if the states are consistent, the model should produce the same output in both normal-training and checkpoint-loading modes (see the print(loss.item()) near the end of the training loop). That is exactly what happens with use_dropout=0 (i.e. the model does not use Dropout): the final printed values are identical. But with use_dropout=1, the values printed in the two modes are different.
This matters because most models contain Dropout layers (BERT, for example). With limited resources we often cannot finish training in a single run, and because of the dropout, the performance gap between normal training and resumed training is startling (at least in the task I am currently working on).
How should this be resolved? One solution I can think of is to save the random state used by the dropout, but I am not sure that is feasible.
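To illustrate what I mean, something like the rough sketch below is the idea (it is not part of my script, and I am not sure it really covers dropout's randomness):

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
drop = nn.Dropout(0.2).to(device)   # a fresh module is in training mode, so dropout is active
x = torch.ones(4, 8, device=device)

# capture the global generator states that dropout draws from
cpu_state = torch.get_rng_state()
cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

out1 = drop(x)                      # consumes the generator

# restore the states and run the same forward pass again
torch.set_rng_state(cpu_state)
if cuda_states is not None:
    torch.cuda.set_rng_state_all(cuda_states)
out2 = drop(x)

print(torch.equal(out1, out2))      # I would expect True if restoring the state is enough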
My current version of PyTorch is 1.13.0+cu116.
import os, argparse, pickle, sys
import random, torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
SEED = 1
parser = argparse.ArgumentParser(description='KGTL')
parser.add_argument('--cuda', nargs='+', type=int, default=[0,1,2,3], help="cuda")
parser.add_argument('--device', type=int, default=0, help='-1 is CPU')
parser.add_argument('--load_checkpoint', type=int, default=1, help="whether to load checkpoint or not, 0 is not")
parser.add_argument('--use_dropout', type=int, default=1, help="whether to use Dropout on model or not, 0 is not")
params = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in params.cuda])
os.environ['WORLD_SIZE'] = str(len(params.cuda))
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# set the random seeds
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True
class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.E = nn.Embedding(5, 128)
        self.linear = nn.Linear(128, 256)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(0.2)

    def forward(self, data):
        if params.use_dropout != 0:
            return self.dropout(
                self.tanh(self.linear(self.E(data)))
            )
        else:
            return self.tanh(self.linear(self.E(data)))
batch_size = 3
save_at_epoch = 6
device = torch.device(f"cuda:{params.device}" if params.device != -1 else "cpu")
load_checkpoint = params.load_checkpoint == 1
dir_path = "./test_checkpoints"
os.makedirs(dir_path, exist_ok=True)
ELEs = [
    "checkpoint.pt", "",
    "random_state.pkl", "torch_random_state.pkl", "torch_cuda_random_state.pkl", "numpy_random_state.pkl",
    "optimizer.pkl"
]
if load_checkpoint:
    # restore the saved random states before the model is built
    with open(os.path.join(dir_path, ELEs[2]), 'rb') as handle:
        random.setstate(pickle.load(handle))
    with open(os.path.join(dir_path, ELEs[3]), 'rb') as handle:
        torch.set_rng_state(pickle.load(handle))
    with open(os.path.join(dir_path, ELEs[4]), 'rb') as handle:
        torch.cuda.set_rng_state_all(pickle.load(handle))
    with open(os.path.join(dir_path, ELEs[5]), 'rb') as handle:
        np.random.set_state(pickle.load(handle))
model = M().to(device)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
loss_func = nn.CrossEntropyLoss()
if load_checkpoint:
    print("loading the checkpoint and states...")
    model.load_state_dict(
        torch.load(os.path.join(dir_path, ELEs[0]), map_location=device)
    )
    with open(os.path.join(dir_path, ELEs[6]), 'rb') as handle:
        opt.load_state_dict(pickle.load(handle))
training_data = [1,2,3,3,2,1,1,2,3,3,2,1,1,2,3,3,2,1,1,2,3,3,2,1]
training_labels = [1 for _ in range(len(training_data))]
for epoch in tqdm(range(8)):
    if epoch == save_at_epoch:
        if not load_checkpoint:
            print("saving the checkpoint and states...")
            # save the random state of random
            with open(os.path.join(dir_path, ELEs[2]), 'wb') as handle:
                pickle.dump(random.getstate(), handle, protocol=pickle.HIGHEST_PROTOCOL)
            # save the random state of torch
            with open(os.path.join(dir_path, ELEs[3]), 'wb') as handle:
                pickle.dump(torch.get_rng_state(), handle, protocol=pickle.HIGHEST_PROTOCOL)
            # save the random state of cuda
            with open(os.path.join(dir_path, ELEs[4]), 'wb') as handle:
                pickle.dump(torch.cuda.get_rng_state_all(), handle, protocol=pickle.HIGHEST_PROTOCOL)
            # save the random state of numpy
            with open(os.path.join(dir_path, ELEs[5]), 'wb') as handle:
                pickle.dump(np.random.get_state(), handle, protocol=pickle.HIGHEST_PROTOCOL)
            # save the state of optimizer
            with open(os.path.join(dir_path, ELEs[6]), 'wb') as handle:
                pickle.dump(opt.state_dict(), handle, protocol=pickle.HIGHEST_PROTOCOL)
            # save the checkpoint of model
            torch.save(model.state_dict(), os.path.join(dir_path, ELEs[0]))
    else:
        if load_checkpoint:
            # in resume mode, skip the epochs that were already trained
            continue
    for index in range(0, len(training_data), batch_size):
        batch = training_data[index:index+batch_size]
        labels = training_labels[index:index+batch_size]
        opt.zero_grad()
        loss = loss_func(
            model(torch.LongTensor(batch).to(device)),
            torch.LongTensor(labels).to(device)
        )
        if epoch == save_at_epoch:
            # both modes should print the same value here
            print(loss.item())
            sys.exit()
        loss.backward()
        opt.step()