Tensors on different devices when using DataParallel: Expected all tensors to be on the same device

Hi, I’m trying to adapt some code to run on multiple GPUs. To do so, I’ve followed this example. It seemed quite simple and seamless. I was able to make it work on a simple dummy example.

However, I’m unable to make it work on my actual code. I keep getting the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Here’s a simplified version of my code.
The Model:

class VGAN_generator(nn.Module):

    def __init__(self, z_dim, hidden_dim, x_dim, layers, col_type, col_ind, condition=False, c_dim=0):
        super(VGAN_generator, self).__init__()
        self.input = nn.Linear(z_dim+c_dim, hidden_dim)
        self.inputbn = nn.BatchNorm1d(hidden_dim)
        self.hidden = []
        self.BN = []
        self.col_type = col_type
        self.col_ind = col_ind
        self.x_dim = x_dim
        self.c_dim = c_dim
        self.condition = condition
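        for i in range(layers):
            fc = nn.Linear(hidden_dim, hidden_dim)
            setattr(self, "fc%d"%i, fc)
            self.hidden.append(fc)  # also kept in the plain list that forward() indexes
            bn = nn.BatchNorm1d(hidden_dim)
            setattr(self, "bn%d"%i, bn)
            self.BN.append(bn)  # same for the batch norm layers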
        self.output = nn.Linear(hidden_dim, x_dim)
        self.outputbn = nn.BatchNorm1d(x_dim)
    def forward(self, z):
        z = self.input(z)
        z = self.inputbn(z)
        z = torch.relu(z)
        for i in range(len(self.hidden)):
            z = self.hidden[i](z)
            z = self.BN[i](z)
            z = torch.relu(z)
        x = self.output(z)
        x = self.outputbn(x)
        output = []
        for i in range(len(self.col_type)):
            sta = self.col_ind[i][0]
            end = self.col_ind[i][1]
            if self.col_type[i] == 'binary':
                temp = torch.sigmoid(x[:,sta:end])
            elif self.col_type[i] == 'normalize':
                temp = torch.tanh(x[:,sta:end])
            elif self.col_type[i] == 'one-hot':
                temp = torch.softmax(x[:,sta:end], dim=1)
            elif self.col_type[i] == 'gmm':
                temp1 = torch.tanh(x[:,sta:sta+1])
                temp2 = torch.softmax(x[:,sta+1:end], dim=1)
                temp = torch.cat((temp1,temp2),dim=1)
            elif self.col_type[i] == 'ordinal':
                temp = torch.sigmoid(x[:,sta:end])
            output.append(temp)  # collect each column block before concatenating
        output = torch.cat(output, dim=1)
        return output

The Training:

def V_Train(G, epochs, lr, dataloader, z_dim, device, steps_per_epoch = None):

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    if torch.cuda.device_count() > 1:
        G = nn.DataParallel(G, device_ids=[0,1])


    G_optim = optim.Adam(G.parameters(), lr=lr, weight_decay=0.00001)

    # the default # of steps is the # of batches.
    if steps_per_epoch is None:
        steps_per_epoch = len(dataloader)

    for epoch in range(epochs):
        it = 0
        while it < steps_per_epoch:
            for x_real in dataloader:
                x_real = x_real.to(device)

                z = torch.randn(x_real.shape[0], z_dim)
                z = z.to(device)
                x_fake = G(z) # ERROR HAPPENS HERE
                # MORE CODE BELOW
                # [...]
    return G

Finally, here’s how the training is called:

device = torch.device("cuda:0" if GPU else "cpu")
V_Train(G, epochs, lr, dataloader, z_dim, device, steps_per_epoch)

I’ve had a hard time finding anyone online with the same issue. I’ve seen mentions of tensor operations in the model’s __init__ causing problems, but I don’t seem to have any of that here. My guess is that the batch norm is causing synchronization issues between devices, but I don’t know why or how to fix it.

I should add that my code runs without issues on a single GPU.

Thanks for your help!

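You are creating self.hidden and self.BN as plain Python lists, which won’t work with nn.DataParallel. When the model is replicated, only registered submodules are swapped for per-device copies; a plain list is copied by reference, so the replica on cuda:1 still calls layers whose parameters live on cuda:0, which raises this error.
To properly register the modules, use nn.ModuleList instead and it’ll work.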
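As a rough sketch of that change (a trimmed-down stand-in, not your full model: the class name Generator and the sizes are made up for illustration, and the per-column activations are left out), the layers can go straight into nn.ModuleList containers:

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Hypothetical, trimmed-down stand-in for VGAN_generator."""

    def __init__(self, z_dim, hidden_dim, x_dim, layers):
        super().__init__()
        self.input = nn.Linear(z_dim, hidden_dim)
        self.inputbn = nn.BatchNorm1d(hidden_dim)
        # nn.ModuleList registers each layer as a submodule, so .to(device),
        # .parameters() and nn.DataParallel all see them.
        self.hidden = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(layers)]
        )
        self.BN = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(layers)]
        )
        self.output = nn.Linear(hidden_dim, x_dim)

    def forward(self, z):
        z = torch.relu(self.inputbn(self.input(z)))
        for fc, bn in zip(self.hidden, self.BN):
            z = torch.relu(bn(fc(z)))
        return self.output(z)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G = Generator(z_dim=8, hidden_dim=16, x_dim=4, layers=3).to(device)
if torch.cuda.device_count() > 1:
    # Each replica now gets its own copy of the hidden layers.
    G = nn.DataParallel(G)

x_fake = G(torch.randn(32, 8, device=device))
print(x_fake.shape)
```

With the containers in place, the separate setattr calls are no longer needed, because nn.ModuleList already registers each layer under an index.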


This is exactly it! Thank you.

I was also storing lists of Modules inside Python dictionaries. I fixed it by using nn.ModuleDict to store nn.ModuleList objects.
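In case it helps someone else, here is a minimal sketch of that pattern. The Branches class, the branch names, and the sizes are hypothetical and only illustrate the nesting, not my actual code:

```python
import torch
import torch.nn as nn


class Branches(nn.Module):
    """Hypothetical module holding several named stacks of layers."""

    def __init__(self, dim, depths):
        super().__init__()
        # depths maps a branch name to its number of layers, e.g. {"a": 2}.
        # nn.ModuleDict registers every nn.ModuleList stored in it.
        self.branches = nn.ModuleDict({
            name: nn.ModuleList([nn.Linear(dim, dim) for _ in range(n)])
            for name, n in depths.items()
        })

    def forward(self, x, name):
        for layer in self.branches[name]:
            x = torch.relu(layer(x))
        return x


model = Branches(dim=4, depths={"encoder": 2, "decoder": 3})
# Each Linear layer contributes a weight and a bias tensor.
n_tensors = len(list(model.parameters()))
y = model(torch.randn(6, 4), "decoder")
```

Because every layer is registered, model.parameters() includes all of them, which a plain dict of lists would not give you.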

