Element-Wise Multiplication - RuntimeError: expected device cpu but got device cuda:0

Hi, I've run into an error and can't figure out what's wrong. Thanks in advance for your help!

I’m trying to apply a gradient penalty to the discriminator loss, but I'm stuck on lines 142-144. I need to do the interpolation, but any operation I apply to fake_img gives me the error “expected device cpu but got device cuda:0”.

I printed the type and size of real_img, fake_img, and alpha. They are all of type Tensor with size [16, 3, 128, 128]. I don’t know what went wrong.

Operations on real_img work fine. The error only shows up when I use fake_img.

114             ############################
115             # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
116             ###########################
117             real_img = Variable(target)
118             target_real_out = Variable(torch.ones(batch_size, 1))
119 
120             real_img.to(device)
121             target_real_out.to(device)
122             print(type(real_img), real_img.size())
123 
124             z = Variable(data)
125             target_fake_out = Variable(torch.zeros(batch_size, 1))
126 
127             z.to(device)
128             target_fake_out.to(device)
129 
130             fake_img = netG(z)
131             print(type(Variable(fake_img.data)), Variable(fake_img.data).size())
132 
133             netD.zero_grad()
134             real_out = netD(real_img)
135             fake_out = netD(Variable(fake_img.data))
136 
137             alpha = torch.rand(batch_size, 1)
138             alpha = alpha.expand(batch_size, real_img.nelement()//batch_size).contiguous().view(real_img.size())
139             alpha.to(device)
140             print(type(alpha), alpha.size())
141 
142             interpolates = (1 - alpha) * fake_img
143             # interpolates = alpha * real_img + ((1 - alpha) * fake_img)
144             interpolates.to(device)
145 
146             interpolates = Variable(interpolates, requires_grad=True)
147             disc_interpolates = netD(interpolates)
148             gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
149                     grad_outputs=torch.ones(disc_interpolates.size()).to(device),
150                     create_graph=True, retain_graph=True, only_inputs=True)[0]
151             gradients = gradients.view(gradients.size(0), -1)
152             gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
153 
154             d_loss = -(torch.mean(real_out) - torch.mean(fake_out)) + gradient_penalty
155 
156 
157 
158             if torch.cuda.device_count() > 1:
159                 d_loss.mean().backward(retain_graph=True)
160             else:
161                 d_loss.backward(retain_graph=True)
162             optimizerD.step()
# generator parameters: 1769728
# discriminator parameters: 14499401
Let's use 4 GPUs!
Reading checkpoint...
Not find any generator model! Training from scratch with  mse  loss!
RTGAN Training Starts!
  0%|                                                                        | 0/632 [00:00<?, ?it/s]<class 'torch.Tensor'> torch.Size([16, 3, 128, 128])
<class 'torch.Tensor'> torch.Size([16, 3, 128, 128])
<class 'torch.Tensor'> torch.Size([16, 3, 128, 128])
  0%|                                                                        | 0/632 [00:09<?, ?it/s]
Traceback (most recent call last):
  File "train_rtgan.py", line 142, in <module>
    interpolates = (1 - alpha) * fake_img
RuntimeError: expected device cpu but got device cuda:0

Your “alpha” variable was sent to the GPU on line 139; you need to do the same with the “fake_img” variable, or keep both of them on the CPU.

Thanks for your reply.

fake_img comes from netG(z), and netG is spread across multiple GPUs. Please see lines 46, 58, and 63.

 37     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 38 
 39     train_set = DatasetFromFolder(TRAIN_DATA_PATH, resize=RESIZE)
 40     val_set = DatasetFromFolder(VAL_DATA_PATH, resize=RESIZE)
 41     train_loader = DataLoader(dataset=train_set, num_workers=TRAIN_NUM_WORKERS, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
 42     val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
 43 
 44     writer = SummaryWriter(log_dir='logs/current_model')
 45 
 46     netG = Generator()
 47     print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
 48     netD = Discriminator()
 49     print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
 50 
 51     generator_criterion = GeneratorLoss()
 52     # adversarial_criterion = nn.BCELoss()
 53     adversarial_criterion = nn.BCEWithLogitsLoss()
 54     mse_loss = nn.MSELoss()
 55 
 56     if torch.cuda.device_count() > 1:
 57         print("Let's use", torch.cuda.device_count(), "GPUs!")
 58         netG = nn.DataParallel(netG)
 59         netD = nn.DataParallel(netD)
 60         generator_criterion = nn.DataParallel(generator_criterion)
 61         adversarial_criterion = nn.DataParallel(adversarial_criterion)
 62         mse_loss = nn.DataParallel(mse_loss)
 63     netG.to(device)
 64     netD.to(device)
 65     generator_criterion.to(device)
 66     adversarial_criterion.to(device)
 67     mse_loss.to(device)

If netG is running on multiple GPUs, isn’t its output, fake_img, also on multiple GPUs? Should I move it again with fake_img.to(device)? I tried that; here are the changes I made and the resulting error (the same one):

142             fake_img.to(device)
143             interpolates = (1 - alpha) * fake_img  
144             # interpolates = alpha * real_img + ((1 - alpha) * fake_img)  
145             interpolates.to(device) 

and

Let's use 4 GPUs!
Reading checkpoint...
RTResNet-MSE model restore success! Start training in RTGAN with  mse  loss!
RTGAN Training Starts!
 0%|                                                                        | 0/632 [00:00<?, ?it/s]
<class 'torch.Tensor'> torch.Size([16, 3, 128, 128])
<class 'torch.Tensor'> torch.Size([16, 3, 128, 128])
<class 'torch.Tensor'> torch.Size([16, 3, 128, 128])
 0%|                                                                        | 0/632 [00:11<?, ?it/s]
Traceback (most recent call last):
    File "train_rtgan.py", line 143, in <module>
        interpolates = (1 - alpha) * fake_img
RuntimeError: expected device cpu but got device cuda:0

Thank you again for helping me. I appreciate it.

Here’s an update on my debugging.

What I found is that only fake_img is on cuda:0; alpha and real_img are not. I don’t understand why, because I did call alpha.to(device) and real_img.to(device).

As for fake_img, whether or not I call fake_img.to(device), it is always on cuda:0. Even when I write fake_img.to(device='cpu'), the output still shows fake_img on cuda:0. I think that is because fake_img is an output of netG, which has been wrapped in nn.DataParallel() to run on multiple GPUs. When I check the outputs of netD, they are also on cuda:0.

I printed the sizes of real_img, fake_img, and alpha, as well as their values.

real_img size torch.Size([16, 3, 128, 128])
fake_img size torch.Size([16, 3, 128, 128])
alpha size torch.Size([16, 3, 128, 128])

real_img tensor([[[[0.8157, 0.8196, 0.8157,  ..., 0.8980, 0.9020, 0.9059],...[0.5608, 0.5647, 0.5608,  ..., 0.8314, 0.8275, 0.8196]]]])
fake_img tensor([[[[0.1948, 0.3402, 0.4361,  ..., 0.3754, 0.9775, 0.9215],...[0.5765, 0.6778, 0.4765,  ..., 0.0668, 0.1717, 0.1910]]]], device='cuda:0', grad_fn=<GatherBackward>)
alpha tensor([[[[0.2796, 0.2796, 0.2796,  ..., 0.2796, 0.2796, 0.2796],...[0.3608, 0.3608, 0.3608,  ..., 0.3608, 0.3608, 0.3608]]]])

Later I tried the code at the Python prompt instead of running the script. The code is copied and pasted from the script that throws the error, but this time real_img and alpha are on cuda:0. There must be an error in my script that leaves real_img and alpha on the CPU even though I send them to the device. Below is the code I typed at the Python prompt:

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> train_set = DatasetFromFolder('/dockerx/rtgan_data/folder5_all', if_resize=True, resize_val=256)
>>> train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
>>> for index, (data, target) in enumerate(train_loader):
...     real_img = Variable(target)
...     real_img.to(device)
...     print(real_img)

tensor([[[[0.8157, 0.8196, 0.8235,  ..., 0.9020, 0.9020, 0.9059],......[0.5686, 0.5686, 0.5725,  ..., 0.8235, 0.8235, 0.8235]]]], device='cuda:0')

So at the prompt, real_img is successfully sent to cuda:0. I don’t understand why, in my script, real_img and alpha are not sent to cuda:0.

Thank you for helping!
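A note on the debugging output above: printing type() and size() does not show where a tensor lives, so all three tensors look identical even when they sit on different devices. A minimal sketch of a more telling check follows; the helper names describe and same_device are made up for illustration and are not part of PyTorch.

```python
import torch


def describe(name, tensor):
    """Return a one-line summary that includes the tensor's device."""
    return (f"{name}: device={tensor.device}, "
            f"dtype={tensor.dtype}, shape={tuple(tensor.shape)}")


def same_device(*tensors):
    """True if every tensor reports the same device."""
    return len({t.device for t in tensors}) == 1


# Small stand-ins for the real tensors in the script.
alpha = torch.rand(2, 3)
fake_img = torch.rand(2, 3)

print(describe("alpha", alpha))
print(describe("fake_img", fake_img))
print("same device:", same_device(alpha, fake_img))
```

Calling something like this right before the failing multiplication should make a CPU/GPU mismatch visible immediately.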

Can you try this, maybe?

alpha = alpha.to(device)

I think the ‘to’ function is not in-place. This is what I get when I run it in my terminal:

>>> a = torch.tensor([1., 2., 3.])
>>> a.to('cuda:0')
tensor([1., 2., 3.], device='cuda:0')
>>> a
tensor([1., 2., 3.])
>>> a = a.to('cuda:0')
>>> a
tensor([1., 2., 3.], device='cuda:0')
>>> 

But it behaves differently when I use a layer (Module):

>>> linear = nn.Linear(2, 2)
>>> linear.weight.data
tensor([[ 0.3646,  0.5939],
        [-0.2984,  0.4559]])
>>> linear.to('cuda:0')
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight.data
tensor([[ 0.3646,  0.5939],
        [-0.2984,  0.4559]], device='cuda:0')
>>> 

So, yeah, a bit confusing, which is why I have always treated the ‘to’ function as NOT in-place and always used

model = model.to('cuda:0')
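The same difference can be checked on a machine without a GPU by converting dtype instead of device, since both go through the same 'to' call. This is only a sketch to illustrate the return-value behaviour; the variable names are arbitrary.

```python
import torch
import torch.nn as nn

# Tensor.to returns a new tensor; the original is left untouched.
a = torch.tensor([1.0, 2.0, 3.0])        # float32 by default
b = a.to(torch.float64)
assert a.dtype == torch.float32
assert b.dtype == torch.float64

# If nothing needs to change, the very same tensor is returned.
assert a.to(torch.float32) is a

# Module.to converts the parameters in place and returns the module itself.
linear = nn.Linear(2, 2)
returned = linear.to(torch.float64)
assert returned is linear
assert linear.weight.dtype == torch.float64

# The pattern that is safe for both tensors and modules: always reassign.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = a.to(device)
linear = linear.to(device)
print(moved.device, linear.weight.device)
```

This probably also explains the Python prompt session earlier in the thread: the interactive interpreter echoes the value of a bare expression such as real_img.to(device), even inside a loop, so the cuda:0 tensor shown there was most likely the returned copy rather than real_img itself.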

Thank you! It did solve the problem:)

This is a good lesson for me: I just learned that the ‘to’ function is not in-place for tensors, but is for model layers. From now on I will just treat it as a non-in-place function.

Thank you again
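For anyone who lands here with the same error, below is a rough sketch of the interpolation and gradient penalty step with every 'to' result reassigned. It is not the original training script: the function name gradient_penalty, the weight argument, and the tiny stand-in discriminator are assumptions made so the example runs on its own. It uses the full interpolation (the line that is commented out in the original post) and creates alpha directly on the target device.

```python
import torch
from torch import autograd, nn


def gradient_penalty(netD, real_img, fake_img, device, weight=10.0):
    """Gradient penalty on random interpolates between real and fake images."""
    batch_size = real_img.size(0)

    # Reassign: Tensor.to returns a new tensor instead of moving in place.
    real_img = real_img.to(device)
    fake_img = fake_img.detach().to(device)

    # One alpha per sample, created on the right device and broadcast
    # over the channel and spatial dimensions.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)

    interpolates = alpha * real_img + (1 - alpha) * fake_img
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * weight


# Stand-in discriminator and data, only to show the call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1)).to(device)
real = torch.rand(4, 3, 8, 8)
fake = torch.rand(4, 3, 8, 8)
gp = gradient_penalty(netD, real, fake, device)
print(gp.item())
```

Creating alpha with device=device and using torch.ones_like for grad_outputs avoids a separate 'to' call for those tensors, which removes two places where the reassignment could be forgotten.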