How to use DataParallel in PyTorch while using parts of a model separately?

I am working on training a Variational Autoencoder (VAE) to disentangle factors of variation in a dataset. To do that, the encode and decode parts of the VAE are called separately when computing the loss, because the common factor of two images that share the same label is swapped between them before decoding. With this setup, only one GPU is used even though I wrap my model in DataParallel. I have trained other models on the same machine that successfully use both GPUs. How can I fix this? I also found a somewhat related question on the PyTorch discussion forum, but it has no answers here.

I’m unsure what exactly you’re trying to do. Are you trying to run different parts of your model on different GPUs? If so, @ailzhang’s answer here might be relevant.
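Roughly, that approach (model parallelism rather than DataParallel) looks like the sketch below; the layer shapes and device ids are placeholders for illustration, not taken from the linked answer:

import torch.nn as nn

class TwoGPUModel(nn.Module):
    # hypothetical example: encoder lives on cuda:0, decoder on cuda:1
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(128, 64).to('cuda:0')
        self.decoder = nn.Linear(64, 128).to('cuda:1')

    def forward(self, x):
        h = self.encoder(x.to('cuda:0'))
        # move the intermediate activation to the second GPU by hand
        return self.decoder(h.to('cuda:1'))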

If not, can you describe what you’re trying to do in more detail?

Hi @daemonslayer, please make sure you merge your encoder model & decoder model into a single model and then apply DP on that model. If this doesn’t solve your problem, feel free to paste a minimal repro of your script so that we can help.
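Concretely, something like this minimal sketch (the submodules here are just placeholders standing in for the real encoder and decoder):

import torch.nn as nn

class CombinedVAE(nn.Module):
    def __init__(self):
        super().__init__()
        # placeholder submodules; in practice these are the real encoder and decoder
        self.encoder = nn.Linear(784, 32)
        self.decoder = nn.Linear(32, 784)

    def forward(self, x):
        return self.decoder(self.encoder(x))

# wrap the single combined module, not the encoder/decoder separately
model = nn.DataParallel(CombinedVAE()).to('cuda')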

Hi @ailzhang, thanks for the reply. Here is my model code:

#!/usr/bin/env python

import torch
import torch.nn as nn
from torch.autograd import Variable

import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class VAE(nn.Module):
    def __init__(self, nc, ngf, ndf, latent_variable_size):
        super(VAE, self).__init__()

        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size

        # encoder
        self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(ndf)

        self.e2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(ndf*2)

        self.e3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(ndf*4)

        self.e4 = nn.Conv2d(ndf*4, ndf*8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(ndf*8)

        self.e5 = nn.Conv2d(ndf*8, ndf*8, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(ndf*8)

        self.fc1 = nn.Linear(ndf*8*4*4, latent_variable_size)
        self.fc2 = nn.Linear(ndf*8*4*4, latent_variable_size)
        self.fc3 = nn.Linear(ndf*8*4*4, latent_variable_size)

        # decoder
        self.d1 = nn.Linear(latent_variable_size, ngf*8*2*4*4)

        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd1 = nn.ReplicationPad2d(1)
        self.d2 = nn.Conv2d(ngf*8*2, ngf*8, 3, 1)
        self.bn6 = nn.BatchNorm2d(ngf*8, 1.e-3)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd2 = nn.ReplicationPad2d(1)
        self.d3 = nn.Conv2d(ngf*8, ngf*4, 3, 1)
        self.bn7 = nn.BatchNorm2d(ngf*4, 1.e-3)

        self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd3 = nn.ReplicationPad2d(1)
        self.d4 = nn.Conv2d(ngf*4, ngf*2, 3, 1)
        self.bn8 = nn.BatchNorm2d(ngf*2, 1.e-3)

        self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd4 = nn.ReplicationPad2d(1)
        self.d5 = nn.Conv2d(ngf*2, ngf, 3, 1)
        self.bn9 = nn.BatchNorm2d(ngf, 1.e-3)

        self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd5 = nn.ReplicationPad2d(1)
        self.d6 = nn.Conv2d(ngf, nc, 3, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.leakyrelu(self.bn1(self.e1(x)))
        h2 = self.leakyrelu(self.bn2(self.e2(h1)))
        h3 = self.leakyrelu(self.bn3(self.e3(h2)))
        h4 = self.leakyrelu(self.bn4(self.e4(h3)))
        h5 = self.leakyrelu(self.bn5(self.e5(h4)))
        h5 = h5.view(-1, self.ndf*8*4*4)

        return self.fc1(h5), self.fc2(h5), self.fc3(h5)

    def reparameterize(self, mu, logvar):
        # std = logvar.mul(0.5).exp_()
        # eps = torch.cuda.FloatTensor(std.size()).normal_()
        # eps = Variable(eps).cuda()
        # return eps.mul(std).add_(mu)
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, common_factor, varying_factor):
        h1_c = self.relu(self.d1(common_factor))
        h1_v = self.relu(self.d1(varying_factor))
        # h1 = torch.cat([h1_c, h1_v], dim=0)
        h1 = h1_c + h1_v
        h1 = h1.view(-1, self.ngf*8*2, 4, 4)
        h2 = self.leakyrelu(self.bn6(self.d2(self.pd1(self.up1(h1)))))
        h3 = self.leakyrelu(self.bn7(self.d3(self.pd2(self.up2(h2)))))
        h4 = self.leakyrelu(self.bn8(self.d4(self.pd3(self.up3(h3)))))
        h5 = self.leakyrelu(self.bn9(self.d5(self.pd4(self.up4(h4)))))

        return self.sigmoid(self.d6(self.pd5(self.up5(h5))))

    def get_latent_var(self, x):
        common_factor, mu, logvar = self.encode(x.view(-1, self.nc, self.ndf, self.ngf))
        z = self.reparameterize(mu, logvar)
        return z

    def forward(self, x):
        common_factor, mu, logvar = self.encode(x.view(-1, self.nc, self.ndf, self.ngf))
        z = self.reparameterize(mu, logvar)
        # decode() takes both a common factor and a varying factor
        res = self.decode(common_factor, z)
        return res, mu, logvar



model = VAE(nc=3, ngf=128, ndf=128, latent_variable_size=500)
model = nn.DataParallel(model).to(device)

model.train()

train_data_1 = np.random.randn(10, 20, 3, 128, 128)
train_data_2 = np.random.randn(10, 20, 3, 128, 128)

train_data_1 = torch.from_numpy(train_data_1).float()
train_data_2 = torch.from_numpy(train_data_2).float()

# train_data_1 = torch.unsqueeze(train_data_1, 0)
# train_data_2 = torch.unsqueeze(train_data_2, 0)

def vae_loss_function(image1, decoded_1, image2, decoded_2, mu_1, logvar_1):
    mse = nn.MSELoss()
    # do not detach the decoder outputs here, otherwise no gradient reaches the decoder
    mse_loss_1 = mse(decoded_1, image1)
    mse_loss_2 = mse(decoded_2, image2)
    # KL divergence; logvar_1 is the log-variance, consistent with reparameterize()
    kl_div = -0.5 * torch.sum(1 + logvar_1 - mu_1.pow(2) - logvar_1.exp())

    total_loss = mse_loss_1 + mse_loss_2 + kl_div
    return total_loss

num_epochs = 100
lr = 0.001

optimizer = torch.optim.Adam(model.parameters(), lr)

train_data_1 = train_data_1.to(device)
train_data_2 = train_data_2.to(device)

for epoch in range(num_epochs):
    for batch_idx, image in enumerate(train_data_1):
        common_factor_1, mu_1, logvar_1 = model.module.encode(image)
        reparam_out_1 = model.module.reparameterize(mu_1, logvar_1)
        common_factor_2, mu_2, logvar_2 = model.module.encode(train_data_2[batch_idx])
        reparam_out_2 = model.module.reparameterize(mu_2, logvar_2)

        decoded_1 = model.module.decode(common_factor_2, reparam_out_1)
        decoded_2 = model.module.decode(common_factor_1, reparam_out_2)

        loss = vae_loss_function(image, decoded_1, train_data_2[batch_idx], decoded_2, mu_1, logvar_1)

        optimizer.zero_grad()
        loss.backward()
        train_loss = loss.item()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tlr: {:.6f}'.format(
                    epoch + 1, batch_idx, len(train_data_1),
                    100. * batch_idx / len(train_data_1),
                    loss.item(), lr)
            )

This code is not able to use more than one of the GPUs available to it. What is wrong here?

Hi @daemonslayer, the problem is that you are training through model.module, which bypasses DataParallel and runs everything on the default device. With DP, you should put all forward behavior inside the forward() function and just call model(input), instead of calling model.module.encode / model.module.decode separately.
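For this model, that could look roughly like the sketch below, reusing the encode/decode/reparameterize methods and the optimizer, data, and loss function from the script above (the two-input forward() signature and the swap logic are one illustration of the idea, not the only way to write it):

class VAE(nn.Module):
    # ... same __init__/encode/decode/reparameterize as in the posted model ...

    def forward(self, x1, x2):
        # everything below runs inside the replicated module, so DataParallel
        # can split both input batches across the available GPUs
        common_1, mu_1, logvar_1 = self.encode(x1)
        common_2, mu_2, logvar_2 = self.encode(x2)
        z1 = self.reparameterize(mu_1, logvar_1)
        z2 = self.reparameterize(mu_2, logvar_2)
        # swap the common factors inside forward() instead of in the training loop
        decoded_1 = self.decode(common_2, z1)
        decoded_2 = self.decode(common_1, z2)
        return decoded_1, decoded_2, mu_1, logvar_1

model = nn.DataParallel(VAE(nc=3, ngf=128, ndf=128, latent_variable_size=500)).to(device)

for epoch in range(num_epochs):
    for image_1, image_2 in zip(train_data_1, train_data_2):
        # call the wrapped model, not model.module
        decoded_1, decoded_2, mu_1, logvar_1 = model(image_1, image_2)
        loss = vae_loss_function(image_1, decoded_1, image_2, decoded_2, mu_1, logvar_1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()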