I’m trying to run my code on two GPUs. However, it raises the error: tensors are on different GPUs. I printed the device of each input and variable, and the devices are consistent. How can I solve this? Thanks.
Could you provide a small code example so that we could have a look?
Thanks for your reply.
My code is a little complex; it is meta-learning code. If possible, could I send it to you via email? Thanks.
Send it via PM please.
Also, could you replace all data tensors etc. with random values?
If that’s too complicated, just write down the expected shape for the tensors, so that I don’t have to load a dataset.
I’ve sent you the link via message. Thanks a lot for your help.
Thanks for the code. It’s quite huge and somehow it cannot import
However, I skimmed through the code, and one reason might be that you initialize the states of your LSTM in
This will push the states to the default GPU, and DataParallel might throw the error.
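One way to confirm this, as a sketch assuming PyTorch 0.4 or newer: the helper below lists the device of every parameter, buffer and plain tensor attribute on a module. Printing only the inputs and variables can miss a state tensor stored as a plain attribute. `tensor_devices` and `Example` are hypothetical names, not part of the original code.

```python
import torch
import torch.nn as nn


def tensor_devices(module):
    """Map every tensor a module holds to its device (hypothetical helper)."""
    devices = {}
    for name, param in module.named_parameters():
        devices[name] = param.device
    for name, buf in module.named_buffers():
        devices[name] = buf.device
    # Tensors assigned as plain attributes (self.x = torch.zeros(...)) land in
    # the module's __dict__; .cuda()/.to() and DataParallel do not move them.
    for mod_name, mod in module.named_modules():
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor):
                key = f"{mod_name}.{attr}" if mod_name else attr
                devices[key] = value.device
    return devices


class Example(nn.Module):
    """Hypothetical module reproducing the pattern: state kept as a plain attribute."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=4, hidden_size=8)
        self.lstm_c0 = torch.zeros(1, 2, 8)  # not a parameter or buffer


for name, device in tensor_devices(Example()).items():
    print(name, device)
```

On a GPU machine, calling it on `Example().cuda()` should list the LSTM weights on `cuda:0` but `lstm_c0` still on `cpu`, because the plain attribute is not moved along with the module.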
You could fix this by checking which GPU the current LSTM parameters are on and creating the states on that device:
```python
import torch

# device of this module's parameters
device = next(self.parameters()).device
self.lstm_c0 = torch.zeros(...).to(device)
```
Well, this would work in PyTorch 0.4.0. Since you are using an older version, you could try the following:
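```python
import torch
from torch.autograd import Variable  # needed on PyTorch < 0.4

param = next(self.parameters())
is_cuda = param.is_cuda
if is_cuda:
    # index of the GPU that holds this module's parameters
    cuda_device = param.get_device()
    self.lstm_c0 = Variable(torch.zeros(...)).cuda(cuda_device)
```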
The code isn’t tested, so let me know if you encounter any errors.
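On PyTorch 0.4 or newer, another option is to build the initial states inside `forward()` from the input itself, so each DataParallel replica creates them on its own GPU. A minimal sketch, assuming a batch-first `nn.LSTM` whose zero initial states depend on the batch size; `LSTMWithInitState` is a hypothetical module, not the original code:

```python
import torch
import torch.nn as nn


class LSTMWithInitState(nn.Module):
    """Hypothetical module: initial states are created per call, not in __init__."""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len, input_size). Under DataParallel each replica gets
        # its own slice of the batch, already on that replica's GPU.
        state_shape = (self.lstm.num_layers, x.size(0), self.lstm.hidden_size)
        # new_zeros allocates on x's device with x's dtype, so the states always
        # sit on the same GPU as the replica running this call.
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMWithInitState(input_size=8, hidden_size=16).to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

x = torch.randn(4, 5, 8, device=device)
print(model(x).shape)  # torch.Size([4, 5, 16])
```

Because the states are never stored on the module, there is nothing left behind on the default GPU when DataParallel replicates it.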
Thanks a lot for your help. I am sorry that you cannot run it; it runs fine on my server.
I also suspect the error is caused by self.lstm_c0 and self.lstm_h0. I tried the code you provided, but it still raises RuntimeError: tensors are on different GPUs.
```python
param = next(self.parameters())
cuda_device = param.get_device()  # GPU index of this module's parameters
self.lstm_c0 = Variable(torch.zeros(self.nParams, self.lstm.hidden_size), requires_grad=False).cuda(cuda_device)
self.lstm_h0 = Variable(torch.zeros(self.nParams, self.lstm.hidden_size), requires_grad=False).cuda(cuda_device)
```
I think we should move this conversation to PMs, as nobody can follow the thread without your code.
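If `self.lstm_c0` and `self.lstm_h0` are constant zeros of fixed shape, one thing worth trying (untested against the original code) is registering them as buffers instead of plain attributes: `.cuda()`/`.to()` move buffers, and `nn.DataParallel` copies them to every replica. A minimal sketch, assuming PyTorch 0.4 or newer and an `nn.LSTMCell` fed with `n_params` rows; `BufferedStateLSTMCell` is a hypothetical name:

```python
import torch
import torch.nn as nn


class BufferedStateLSTMCell(nn.Module):
    """Hypothetical module: constant initial states stored as buffers."""

    def __init__(self, n_params, input_size, hidden_size):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        # Buffers are moved by .cuda()/.to() and copied to each DataParallel
        # replica, unlike plain tensor attributes.
        self.register_buffer("lstm_c0", torch.zeros(n_params, hidden_size))
        self.register_buffer("lstm_h0", torch.zeros(n_params, hidden_size))
        # For comparison only: a plain attribute the module does not track.
        self.untracked_c0 = torch.zeros(n_params, hidden_size)

    def forward(self, x):
        # x: (n_params, input_size); the buffers are already on the module's device.
        h, c = self.cell(x, (self.lstm_h0, self.lstm_c0))
        return h, c


m = BufferedStateLSTMCell(n_params=3, input_size=2, hidden_size=4)
h, c = m(torch.randn(3, 2))
print(h.shape)  # torch.Size([3, 4])

# .double() stands in for a device move on a CPU-only machine.
m64 = BufferedStateLSTMCell(n_params=3, input_size=2, hidden_size=4).double()
print(m64.lstm_c0.dtype)       # torch.float64: buffer converted with the module
print(m64.untracked_c0.dtype)  # torch.float32: plain attribute left as-is
```

One caveat: DataParallel splits the input along its first dimension, so a replica would see fewer rows than `n_params` while the buffers keep their full shape. Whether this fits depends on how the original code feeds the LSTM.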