Loading a model + RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Hi everyone,

I want to load my previously saved model and continue training. This is the code for saving the model. The network is trained on the GPU (and works well there).

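import torch
from datetime import datetime

# Save model (model, optimizer, epoch, loss and chkp_Path come from the training loop)
torch.save({
       'epoch': epoch,
       'model_state_dict': model.state_dict(),
       'optimizer_state_dict': optimizer.state_dict(),
       'loss': loss, }, f'{chkp_Path}/epoch{epoch}_{datetime.now().strftime("%Y%m%d-%H%M%S")}.pt')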

During the loading phase, I load the model on the CPU (to prevent out-of-memory issues) and then transfer it back to the GPU:

import torch

chkPath = 'Path_to_chk_file.pt'
checkpoint = torch.load(chkPath, map_location='cpu')
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
losslogger = checkpoint['loss']

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device) # the network should be on GPU, next(model.parameters()).is_cuda==False

Then, when I start the training phase, I receive the following error. Is there any way to find out where those CPU tensors come from? BTW, the batches from the data loader are moved to the GPU during training.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The above error happens at this line, after the loss calculation:

opt.step()
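One way to find out where the CPU tensors come from is to scan every tensor the model and the optimizer hold and report those that are not on the expected device. A minimal sketch; `find_device_mismatches` is a hypothetical helper, not a PyTorch API, and it deliberately skips Adam's `step` entry because recent PyTorch versions keep that counter on the CPU by design:

```python
import torch
import torch.nn as nn


def find_device_mismatches(model, optimizer, expected):
    """Hypothetical helper: describe every model/optimizer tensor not on `expected`."""
    expected = torch.device(expected)

    def on_expected(t):
        # Compare the device type, and the index only if one was given (e.g. "cuda:0").
        return t.device.type == expected.type and (
            expected.index is None or t.device.index == expected.index
        )

    problems = []
    for name, p in model.named_parameters():
        if not on_expected(p):
            problems.append(f"parameter {name} on {p.device}")
    for name, b in model.named_buffers():
        if not on_expected(b):
            problems.append(f"buffer {name} on {b.device}")

    names = {p: n for n, p in model.named_parameters()}
    # optimizer.state maps each parameter to a dict of tensors, e.g. exp_avg / exp_avg_sq for Adam.
    for p, state in optimizer.state.items():
        for key, value in state.items():
            if key == "step":
                continue  # assumption: a CPU step counter is expected, not an error
            if torch.is_tensor(value) and not on_expected(value):
                problems.append(f"optimizer state {names.get(p, '?')}[{key}] on {value.device}")
    return problems


# Usage: call it right before optimizer.step(); on a CPU-only run nothing is reported.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(3, 4)).sum().backward()
optimizer.step()
print(find_device_mismatches(model, optimizer, "cpu"))  # []
```

On the GPU setup from this thread, the call would be `find_device_mismatches(model, optimizer, "cuda:0")`; any `optimizer state ...` line in the output points at the optimizer rather than the model or the data.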

These lines of code:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device) # the network should be on GPU, next(model.parameters()).is_cuda==False

look suspicious, as it seems that the parameters are not moved to the GPU.
Could you make sure device is indeed cuda:0 (I assume so)?
If so, could you post the model definition, please?

Thanks for your reply. The device is 'cuda:0'. I have explicitly set it to 'cuda:0' and also checked the GPU load; it is running on the GPU.
Please find the model below. It is a (Variational) Autoencoder for 3D images (adapted from here).

class Encoder(nn.Module):
    def __init__(self, latent_dims, VE):
        super(Encoder, self).__init__()
        self.CNN = nn.Sequential(
                        nn.Conv3d(1, 64, kernel_size=5, padding=2, stride=1),
                        nn.ReLU(),
                        nn.MaxPool3d(2),
                       
                        nn.Conv3d(64, 128, kernel_size=5, padding=2, stride=1),
                        nn.ReLU(),
                        nn.MaxPool3d(2),
                                    
                        nn.Conv3d(128, 256, kernel_size=5, padding=2, stride=1),
                        nn.ReLU(),
                        nn.MaxPool3d(2),
                            
                        nn.Conv3d(256, 256, kernel_size=5, padding=2, stride=1),
                        nn.LeakyReLU(0.1),
                        nn.MaxPool3d(2),
                       
                        nn.Flatten(),
            )
        
        self.linear1 = nn.Linear(8**3*256, latent_dims)
        self.linear2 = nn.Linear(8**3*256, latent_dims)

        self.N = torch.distributions.Normal(0, 1)
       
        self.kl = 0
        self.VE = VE
        
    def forward(self, x):
        if self.VE:
            x = self.CNN(x)
            mu =  self.linear1(x)
            sigma = torch.exp(self.linear2(x)) # exp: to ensure the result is positive
            z = mu + sigma*self.N.sample(mu.shape)
            self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        else:
            z = self.CNN(x)
            z = self.linear1(z)
        return z
    
class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.Linear = nn.Linear(latent_dims, 8**3*256)
        self.CNN = nn.Sequential(
                            nn.ConvTranspose3d(256, 256, kernel_size=4, stride=2, padding=1),
                            nn.ReLU(),
                           
                            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
                            nn.ReLU(),
                            
                            nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),
                            nn.ReLU(),
                            
                            nn.ConvTranspose3d(64, 1, kernel_size=4, stride=2, padding=1),
                            nn.Sigmoid()
                          )
    def forward(self, z):
        z = self.Linear(z)
        z = z.view((-1, 256, 8, 8, 8))
        z = self.CNN(z)
        return z
    
class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims, VE):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(latent_dims, VE)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

Thanks for your help.
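A separate pitfall in this model, independent of the optimizer issue: `self.N = torch.distributions.Normal(0, 1)` is a plain attribute rather than a registered parameter or buffer, so `model.to(device)` does not move it, and with `VE` enabled `self.N.sample(mu.shape)` keeps returning CPU tensors that are then combined with CUDA tensors in `forward`. The sketch below shows the effect on a CPU-only machine by converting the dtype instead of the device (`Module.to()` skips the distribution in both cases). The class names are hypothetical, and `torch.randn_like(mu)` is one possible replacement that always follows `mu`'s device and dtype:

```python
import torch
import torch.nn as nn


class NoiseFromDistribution(nn.Module):
    """Mirrors the Encoder pattern: the distribution is a plain attribute."""

    def __init__(self):
        super().__init__()
        self.N = torch.distributions.Normal(0, 1)

    def forward(self, mu):
        return self.N.sample(mu.shape)


class NoiseLikeInput(nn.Module):
    """Alternative: draw N(0, 1) noise with the same device and dtype as mu."""

    def forward(self, mu):
        return torch.randn_like(mu)


mu = torch.zeros(2, 3, dtype=torch.float64)
a = NoiseFromDistribution().to(torch.float64)  # .to() does not reach self.N
b = NoiseLikeInput().to(torch.float64)
print(a(mu).dtype, a.N.loc.dtype)  # torch.float32 torch.float32
print(b(mu).dtype)                 # torch.float64
```

In `Encoder.forward` this would amount to `z = mu + sigma * torch.randn_like(mu)`.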

Thanks for the code! Your model seems to be fine in my setup:

# save    
model = VariationalAutoencoder(1, 1)
model.to('cuda:0')
print(next(model.parameters()).device)
# > cuda:0

torch.save(model.state_dict(), 'tmp.pt')


# load in another script (VariationalAutoencoder has to be defined or imported there as well)
import torch
import torch.nn as nn

model = VariationalAutoencoder(1, 1)
checkpoint = torch.load('tmp.pt', map_location='cpu')
model.load_state_dict(checkpoint)
# > <All keys matched successfully>

model.to('cuda:0')
print(next(model.parameters()).device)
# > cuda:0

Could you check whether you run into the same issue using my code snippet and, if not, where the difference might be?

Thanks for your prompt reply. I am doing something very similar to your code. After loading the checkpoint, the model is on the GPU (print(next(model.parameters()).device) shows cuda:0). However, when the training code reaches optimizer.step() inside the training function, it throws the error. This means that the input and the model are on the same device, because the loss can be computed, at least for the first minibatch. So something inside the loaded optimizer must be wrong: it seems that some of the optimizer's tensors are on the CPU while others are on the GPU. I am not an expert on how the PyTorch optimizer works. Is there anything similar to model.to(device) for the optimizer (like opt.to(device))?

I googled how to move the optimizer parameters from cpu to cuda:0 and found this code from here:

import torch

def optimizer_to(optim, device):
    # move every tensor in the optimizer state (e.g. Adam's exp_avg, exp_avg_sq) to device
    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

Surprisingly, it solved my problem. 🙂 But I do not know why the optimizer parameters remain on the CPU when the loaded model is moved from the CPU to the GPU. Any idea?

This is how I solved the problem:

import os
import sys

import torch

def load_checkpoint(model, optimizer, losslogger, device, filename='checkpoint.pt'):
    # https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244/3
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        losslogger = checkpoint['loss']
        print("=> loaded epoch: {} of checkpoint: '{}'"
              .format(checkpoint['epoch'], filename))
    else:
        print("=> no checkpoint found at '{}'".format(filename))
        sys.exit()
    return model, optimizer, start_epoch, losslogger

if Phase == 'continue_training':
    loss = 0
    model, optimizer, start_epoch, loss = load_checkpoint(model, optimizer, loss, 'cpu', chkPath)
    model = model.to(device)         # move back model to GPU
    # compare against a torch.device: comparing a torch.device with the string 'cuda:0' is always False
    assert next(model.parameters()).device == torch.device('cuda:0')
    optimizer_to(optimizer, device)  # *** move optimizer to GPU ***

Good to hear you’ve solved the issue!
I think loading the optimizer state on the CPU via map_location, before the model is moved to the GPU, is causing the trouble.

Thank you for the discussion.
That's right. The PyTorch optimizer is not straightforward, and something happens internally. Maybe there is an internal bug…
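The behaviour can be explained by how `Optimizer.load_state_dict()` works: it converts each loaded state tensor to the device (and, for floating-point parameters, the dtype) of the matching parameter at the moment it is called. A later `model.to(...)` keeps the same `Parameter` objects, so the optimizer still tracks them, but it does not touch the optimizer's state. A minimal CPU-only sketch, using a dtype change as a stand-in for the CPU to GPU move (the helper `exp_avg_dtype` is hypothetical):

```python
import copy

import torch
import torch.nn as nn

# Train briefly in float32 and keep a copy of the optimizer state as a "checkpoint".
model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(3, 4)).sum().backward()
opt.step()
opt_state = copy.deepcopy(opt.state_dict())


def exp_avg_dtype(optimizer):
    """dtype of Adam's exp_avg buffer for the first parameter."""
    first_param = optimizer.param_groups[0]["params"][0]
    return optimizer.state[first_param]["exp_avg"].dtype


# Order from the original post: load the optimizer state, then convert the model.
m1 = nn.Linear(4, 2)
o1 = torch.optim.Adam(m1.parameters())
o1.load_state_dict(opt_state)  # state is cast to the parameters as they are now (float32)
m1.to(torch.float64)           # parameters change, optimizer state does not
print(exp_avg_dtype(o1))       # torch.float32

# Suggested order: convert the model first, then load the optimizer state.
m2 = nn.Linear(4, 2).to(torch.float64)
o2 = torch.optim.Adam(m2.parameters())
o2.load_state_dict(opt_state)  # state is cast to float64 to match the parameters
print(exp_avg_dtype(o2))       # torch.float64
```

With a device instead of a dtype, the first order leaves `exp_avg` and `exp_avg_sq` on the CPU while the parameters are on `cuda:0`, which matches the error raised in `opt.step()`.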

I ran into the same issue today and managed to fix it by changing the order of the three steps as follows:

  1. call model.to(device) first,
  2. then model.load_state_dict(),
  3. and finally opt.load_state_dict().

I loaded the optimizer's state_dict as the last step, but swapping steps 2 and 3 doesn't seem to make a difference. I guess what matters is moving the model as the first step. Judging from your original post, this might be the issue; maybe you could give it a try.

By the way, I also added map_location="cpu" to torch.load().
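The order above can be wrapped into a small function. A minimal sketch, assuming the checkpoint layout from the first post (`epoch`, `model_state_dict`, `optimizer_state_dict`, `loss`); `resume_from_checkpoint` is a hypothetical name, and the optimizer may be created before the move because `model.to()` keeps the same parameter objects:

```python
import os
import tempfile

import torch
import torch.nn as nn


def resume_from_checkpoint(model, optimizer, path, device):
    """Hypothetical helper: restore a checkpoint in the order suggested above."""
    model.to(device)  # 1. move the model first; its Parameter objects stay the same
    checkpoint = torch.load(path, map_location="cpu")  # load every tensor onto the CPU
    model.load_state_dict(checkpoint["model_state_dict"])  # 2. copies into params already on device
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # 3. state follows the params' device
    return checkpoint["epoch"], checkpoint["loss"]


# Usage on a tiny model: write a checkpoint, then resume into fresh objects.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(3, 4)).sum().backward()
optimizer.step()
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"epoch": 3, "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(), "loss": 0.5}, path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.Adam(new_model.parameters())  # created before the move, which is fine
start_epoch, last_loss = resume_from_checkpoint(new_model, new_optimizer, path, device)
```

Because the optimizer state is loaded only after the parameters are on `device`, no separate `optimizer_to` step is needed in this ordering.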


This solution worked for me. I didn't add map_location="cpu", but I changed the order as suggested in the reply by @guangzhi.

The solution of @guangzhi works for me.