Deterministically training a model with dropout from a checkpoint

Hi Everyone!

I’m running into trouble when trying to deterministically resume training a model with dropout from a checkpoint: the training losses after resuming from the checkpoint don’t match the losses from simply continuing the original run.

I expect I’m missing something simple, but can’t seem to find it.

Minimal code:

import numpy as np
import torch
import torch.nn as nn
import os

def gen_dataset():
    # x is 10d, y is 1d
    x = np.stack([100 * np.random.rand(500) + np.arange(500) for k in range(10)])
    x = x.transpose()
    y = np.arange(500) + 100*np.random.rand(500)
    return x,y

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x).to(torch.float32)
        self.y = torch.tensor(y).to(torch.float32)
        
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    
    def __len__(self):
        return len(self.x)
    
def sse(x, y):
    # sum of squared errors; Python's built-in sum over the first dim returns a 1-element tensor
    return sum((x - y) ** 2)

def train(epochs, model, optimizer):
    model.train()
    for k in range(epochs):
        train_loss = 0
        
        for (temp_x, temp_y) in Tr_Dataloader:
            temp_x = temp_x.to('cuda')
            
            temp_y = temp_y.unsqueeze(-1)
            temp_y = temp_y.to('cuda')
            
            model_out = model(temp_x)
            loss = sse(model_out, temp_y)
            train_loss += loss[0].item()
            
            loss.backward()
            optimizer.step()
            model.zero_grad()
            
        print('epoch', k, '| train loss', train_loss)
    print('---')
    return train_loss, model, optimizer
    
def save(path, model, optimizer):
    Out_Dict = {}
    Out_Dict['model_state_dict'] = model.state_dict()
    Out_Dict['optimizer_state_dict'] = optimizer.state_dict()
    Out_Dict['Numpy_Random_State'] = np.random.get_state()
    Out_Dict['Torch_Random_State'] = torch.get_rng_state()
    torch.save(Out_Dict, path)
    
def load(path, model, optimizer):
    In_Dict = torch.load(path)
        
    model.load_state_dict(In_Dict['model_state_dict'])
    
    # optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3) # does not help
    optimizer.load_state_dict(In_Dict['optimizer_state_dict'])

    np.random.set_state(In_Dict['Numpy_Random_State'])
    torch.random.set_rng_state(In_Dict['Torch_Random_State'])
    torch.backends.cudnn.deterministic = True 
    
    return model, optimizer

def prep_model_and_optimizer():
    model = nn.Sequential(
        nn.Dropout(),
        nn.Linear(in_features = 10, out_features = 1)
        )
    model = model.to('cuda')
    optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)

    return model, optimizer

# %%
# Prep random seed
np.random.seed(3)
torch.manual_seed(3)
torch.backends.cudnn.deterministic = True 
torch.backends.cudnn.benchmark = False 
torch.cuda.manual_seed(3)

# Prep data
x,y = gen_dataset()
Tr_X = x[:400]
Tr_Y = y[:400]

Tr_Dataset = CustomDataset(Tr_X,Tr_Y)
Tr_Dataloader = torch.utils.data.DataLoader(Tr_Dataset, batch_size = 100, shuffle=False)

# Init. Train | save, train | load, train | Init, load, train
model, optimizer = prep_model_and_optimizer()

E1 = 10
print('Training Epochs:', E1)
loss, model, optimizer = train(E1, model, optimizer)

E2 = 2
print('Saving. Training more epochs:',E2)
save_path = os.path.join(os.getcwd(), 'checkpoint')
save(save_path, model, optimizer)
# load(save_path, model, optimizer) # calling save() then load() here has no effect
loss, model, optimizer = train(E2, model, optimizer)

print('Loading. Training more epochs:',E2)
model, optimizer = load(save_path, model, optimizer)
loss, model, optimizer = train(E2, model, optimizer)

print('re-initializing. Loading. Training more epochs:',E2)
model, optimizer = prep_model_and_optimizer()
model, optimizer = load(save_path, model, optimizer)
loss, model, optimizer = train(E2, model, optimizer)

Output:

Training Epochs: 10
epoch 0 | train loss 96417973.0
epoch 1 | train loss 96046790.0
epoch 2 | train loss 86500470.25
epoch 3 | train loss 85599624.25
epoch 4 | train loss 90644446.75
epoch 5 | train loss 77562509.5
epoch 6 | train loss 78515326.5
epoch 7 | train loss 71147675.75
epoch 8 | train loss 67646038.75
epoch 9 | train loss 65107301.75
---
Saving. Training more epochs: 2
epoch 0 | train loss 62031824.75
epoch 1 | train loss 57108137.25
---
Loading. Training more epochs: 2
epoch 0 | train loss 67927928.5
epoch 1 | train loss 69468169.25
---
re-initializing. Loading. Training more epochs: 2
epoch 0 | train loss 66837200.0
epoch 1 | train loss 64987717.75
---

What’s wrong: the three pairs of two-epoch training losses (after saving, after loading, and after re-initializing and loading) should be identical, but they aren’t.

I’ve checked:

  • Calling save(), then load() does not affect the training loss on subsequent epochs
  • The Torch and NumPy random states that are loaded match the ones that were saved (see the check sketched below)
  • Removing the Dropout layer makes the training outputs identical
  • Re-initializing the optimizer inside prep_model_and_optimizer() doesn’t change the outcome
  • Setting torch.backends.cudnn.benchmark = False doesn’t change the outcome
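For the RNG-state check, something along these lines can be run right after load() (a minimal sketch, reusing save_path and the checkpoint keys from the code above):

# Sanity check: run immediately after load(save_path, model, optimizer)
# and compare the live RNG states to the ones stored in the checkpoint.
checkpoint = torch.load(save_path)

np_saved = checkpoint['Numpy_Random_State']
np_live = np.random.get_state()
print('NumPy RNG state matches:',
      all(np.array_equal(a, b) for a, b in zip(np_saved, np_live)))

print('Torch CPU RNG state matches:',
      torch.equal(checkpoint['Torch_Random_State'], torch.get_rng_state()))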

Specific Question:

  1. Any idea what I’m doing wrong?

More general questions:

  1. Do I need to re-initialize the optimizer whenever I create a model, so that it holds references to the correct model parameters?
  2. If I want train() to be a separate function, do I need to keep passing the model and optimizer back and forth, or are they effectively passed by reference?

Thanks!
Platon

edits: clarity and formatting

You’ve got to set your model to evaluation mode with model.eval(). Otherwise it will keep applying dropout, which introduces randomness. You should also run inference without gathering gradients, which you can do inside a with torch.no_grad(): context manager. When you’re ready to train again, make sure to set the model back to training mode with model.train().

model.eval()
with torch.no_grad():
    ...  # do stuff with the model
model.train()  # go back to training mode when you're done with inference

Note: This is a little nitpicky, but the Python convention is to name functions in lower case with words separated by underscores. Names written KindaLikeThis, with each word capitalized and no underscores, are usually reserved for classes, so a reader could mistake your functions for classes. See the style guide here.

1 Like

Thank you for your response!

I’m specifically looking to deterministically train the model from a checkpoint.
And yes, for a given checkpoint, the outputs are consistent in evaluation mode.

(I also appreciate the style feedback)

I’ll adjust the phrasing and formatting of my original post.

Found it!

I needed to save and load the CUDA RNG state as well.

# On save:
Out_Dict['CUDA_Random_State'] = torch.cuda.get_rng_state()
# On load:
torch.cuda.set_rng_state(In_Dict['CUDA_Random_State'])
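Folded into the save()/load() functions from the original post, the whole thing looks roughly like this (a minimal sketch; a single GPU is assumed, for multiple GPUs torch.cuda.get_rng_state_all() / set_rng_state_all() are the equivalents):

def save(path, model, optimizer):
    out_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'Numpy_Random_State': np.random.get_state(),
        'Torch_Random_State': torch.get_rng_state(),
        'CUDA_Random_State': torch.cuda.get_rng_state(),  # the RNG that drives dropout on the GPU
    }
    torch.save(out_dict, path)

def load(path, model, optimizer):
    in_dict = torch.load(path)
    model.load_state_dict(in_dict['model_state_dict'])
    optimizer.load_state_dict(in_dict['optimizer_state_dict'])
    np.random.set_state(in_dict['Numpy_Random_State'])
    torch.set_rng_state(in_dict['Torch_Random_State'])
    torch.cuda.set_rng_state(in_dict['CUDA_Random_State'])  # restore the GPU RNG as well
    return model, optimizer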

Also:

  • Calling torch.cuda.manual_seed() separately is unnecessary - torch.manual_seed() already seeds the CUDA RNG
  • If you re-initialize a model, you do need to hand its parameters to the optimizer (see the sketch below)
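Sketch for the second point, using the functions above: after re-creating the model, the optimizer has to be built from the new model’s parameters before its state is loaded, otherwise it keeps updating the old, orphaned tensors.

# A re-created model has brand-new parameter tensors, so build a fresh optimizer
# around them before restoring both states from the checkpoint.
model, optimizer = prep_model_and_optimizer()   # AdamW over the new model.parameters()
model, optimizer = load(save_path, model, optimizer)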

edits: clarity, additional information

2 Likes

Oops, misunderstood your question. Nice job figuring it out though. I didn’t know there was a CUDA random state to save for deterministic training of non-deterministic models.

1 Like