Data parallel: Expected all tensors to be on the same device, but found at least two devices

I’m training a segmentation vae everything works fine on one GPU but but when I do nn.DataParallel(model) for multi gpu I get the error on the bottom regarding the kl loss.

here’s a simplified version of my code:

class convBlock(nn.Module):
def __init__(self, inCh, nhid, nOp, pool=True,

self.enc1 = nn.Conv2d(inCh,nhid,kernel_size=ker,padding=1)
self.enc2 = nn.Conv2d(nhid,nOp,kernel_size=ker,padding=1) = nn.BatchNorm2d(inCh)

if pool:
self.scale = nn.AvgPool2d(kernel_size=pooling)
self.scale = nn.Upsample(scale_factor=pooling)
self.pool = pool
self.act = nn.ReLU()

def forward(self,x):
x = self.scale(x)
x =
x = self.act(self.enc1(x))
x = self.act(self.enc2(x))
return x

class uVAE(nn.Module):
def __init__(self, nlatent,unet=False,
nhid=8, ker=3, inCh=1,h=640,w=512):
super(uVAE, self).__init__()
self.latent_space = nlatent
self.unet = unet

if not self.unet:
### VAE Encoder with 4 downsampling operations
self.enc11 = nn.Conv2d(inCh,nhid,kernel_size=ker,padding=1)
self.enc12 = nn.Conv2d(nhid,nhid,kernel_size=ker,padding=1)

self.enc2 = convBlock(nhid,2*nhid,2*nhid,pool=True)
self.enc3 = convBlock(2*nhid,4*nhid,4*nhid,pool=True)
self.enc4 = convBlock(4*nhid,8*nhid,8*nhid,pool=True)
self.enc5 = convBlock(8*nhid,16*nhid,16*nhid,pool=True)

self.bot11 = nn.Conv1d(16*nhid,1,kernel_size=1)
self.bot12 = nn.Conv1d(int((h/16)*(w/16)),2*nlatent,kernel_size=1)

### Decoder with 4 upsampling operations
self.bot21 = nn.Conv1d(nlatent,int((h/64)*(w/64)),kernel_size=1)
self.bot22 = nn.Conv1d(1,nhid,kernel_size=1)
self.bot23 = nn.Conv1d(nhid,4*nhid,kernel_size=1)
self.bot24 = nn.Conv1d(4*nhid,16*nhid,kernel_size=1)

### U-net Encoder with 4 downsampling operations
self.uEnc11 = nn.Conv2d(inCh,nhid,kernel_size=ker,padding=1)
self.uEnc12 = nn.Conv2d(nhid,nhid,kernel_size=ker,padding=1)

self.uEnc2 = convBlock(nhid,2*nhid,2*nhid,pool=True,pooling=4)
self.uEnc3 = convBlock(2*nhid,4*nhid,4*nhid,pool=True,pooling=4)
self.uEnc4 = convBlock(4*nhid,8*nhid,8*nhid,pool=True)
self.uEnc5 = convBlock(8*nhid,16*nhid,16*nhid,pool=True)

### Joint U-Net + VAE decoder
if not self.unet:
self.dec5 = convBlock(32*nhid,8*nhid,8*nhid,pool=False)
self.dec5 = convBlock(16*nhid,8*nhid,8*nhid,pool=False)

self.dec4 = convBlock(16*nhid,4*nhid,4*nhid,pool=False)
self.dec3 = convBlock(8*nhid,2*nhid,2*nhid,pool=False,pooling=4)
self.dec2 = convBlock(4*nhid,nhid,nhid,pool=False,pooling=4)

self.dec11 = nn.Conv2d(2*nhid,nhid,kernel_size=ker,padding=1)
self.dec12 = nn.Conv2d(nhid,inCh,kernel_size=ker,padding=1)

self.act = nn.ReLU()
self.mu_0 = torch.zeros((1,nlatent)).to(device)
self.sigma_0 = torch.ones((1,nlatent)).to(device)

self.h = h
self.w = w

def vae_encoder(self,x):
### VAE Encoder
x = self.act(self.enc11(x))
x = self.act(self.enc12(x))
x = self.enc2(x)
x = self.enc3(x)
x = self.enc4(x)
x = self.enc5(x)

z = self.act(self.bot11(x.view(x.shape[0],x.shape[1],-1)))
z = self.bot12(z.permute(0,2,1))

return z.squeeze(-1)

def unet_encoder(self,x_in):
### Unet Encoder
x = []


return x

def decoder(self,x_enc,z=None):
if not self.unet:
### Concatenate latent vector to U-net bottleneck
x = self.act(self.bot21(z.unsqueeze(2)))
x = self.act(self.bot22(x.permute(0,2,1)))
x = self.act(self.bot23(x))
x = self.act(self.bot24(x))

x = x.view(x.shape[0],x.shape[1],
x =,x_enc[-1]),dim=1)
x = self.dec5(x)
x = self.dec5(x_enc[-1])

### Shared decoder
x =,x_enc[-2]),dim=1)
x = self.dec4(x)
x =,x_enc[-3]),dim=1)
x = self.dec3(x)
x =,x_enc[-4]),dim=1)
x = self.dec2(x)
x =,x_enc[-5]),dim=1)

x = self.act(self.dec11(x))
x = self.dec12(x)

return x

def forward(self, x):
kl = torch.zeros(1).to(device)
z = 0.
# Unet encoder result
x_enc = self.unet_encoder(x)

# VAE regularisation
if not self.unet:
emb = self.vae_encoder(x)

# Split encoder outputs into a mean and variance vector
mu, log_var = torch.chunk(emb, 2, dim=1)

# Make sure that the log variance is positive
log_var = softplus(log_var)
sigma = torch.exp(log_var / 2)

# Instantiate a diagonal Gaussian with mean=mu, std=sigma
# This is the approximate latent distribution q(z|x)
posterior = Independent(Normal(loc=mu,scale=sigma),1)
z = posterior.rsample()

# Instantiate a standard Gaussian with mean=mu_0, std=sigma_0
# This is the prior distribution p(z)
prior = Independent(Normal(loc=self.mu_0,scale=self.sigma_0),1)

# Estimate the KLD between q(z|x)|| p(z)
kl = KLD(posterior,prior).sum()

# Outputs for MSE
xHat = self.decoder(x_enc,z)

return kl, xHat

model = uVAE()
criterion = nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
model = nn.DataParallel(model)

for epoch in tqdm(range(epochs    
    for i, (image, mask) in enumerate(train_loader):
    #i, (image, mask) = next(enumerate(train_loader))
        image = image.cuda()
        mask = mask.cuda()

        # Forward pass
        kl, outputs = model(image)
        outputs = torch.sigmoid(outputs)
        rec_loss = criterion(input=outputs, target=mask)
        loss = kl/mask.shape[0] + rec_loss

        # Backward and optimize

I get the following error

RuntimeError                              Traceback (most recent call last)
Cell In[12], line 1
----> 1 t_h, model = train(model=model, criterion=criterion, optimizer=optimizer, epochs=20, train_loader=train_loader, test_loader=test_loader)

Cell In[11], line 26, in train(model, criterion, optimizer, train_loader, test_loader, epochs)
     23 mask = mask.cuda()
     25 # Forward pass
---> 26 kl, outputs = model(image)
     27 outputs = torch.sigmoid(outputs)
     29 rec_loss = criterion(input=outputs, target=mask)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/, in DataParallel.forward(self, *inputs, **kwargs)
    169     return self.module(*inputs[0], **kwargs[0])
    170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
    return fun(p, q)
  File "/home/benx13/.local/lib/python3.10/site-packages/torch/distributions/", line 409, in _kl_normal_normal
    var_ratio = (p.scale / q.scale).pow(2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
I think your error is in this line:
kl = torch.zeros(1).to(device)
what does “device” hold?

“device” should hold cuda0

Just replace it with the same device. I suspect .to(device) and .cuda() invoke two different devices. So you want to use the same device for these variables.