DataParallel doesn't parallelize a custom model across multiple GPUs with nn.ModuleList

I'm trying to write a variation of word2vec in PyTorch with multi-GPU support, but I have been unsuccessful in parallelizing it across multiple GPUs. The code works on a single GPU, and torch.cuda.device_count() returns 2, but the second GPU has 0 memory in use.

import torch
import torch.nn as nn
import torch.optim as optim


class skipgram_discriminator(nn.Module):
    def __init__(self, vocabulary_size=150000, embedding_size=300, learning_rate=1e-4, batch_size=512):
        """
        Initialize a skipgram discriminator.

        - vocabulary_size is the number of unique words in the corpus, used to
          build the dense matrices
        - embedding_size is the dimensionality of the word vectors
        - learning_rate is the learning rate for the discriminator optimizer
        """
        super(skipgram_discriminator, self).__init__()
        self.embedding_size = embedding_size
        self.batch_size = batch_size
        self.dis_embeddings = nn.Embedding(vocabulary_size, self.embedding_size, sparse=False)
        self.D_W2 = nn.Embedding(vocabulary_size, self.embedding_size, sparse=False)
        self.D_b2 = nn.Embedding(vocabulary_size, 1)
        self.dis_embeddings.weight.data.uniform_(-0.5 / self.embedding_size, 0.5 / self.embedding_size)
        self.D_W2.weight.data.uniform_(-0.5 / self.embedding_size, 0.5 / self.embedding_size)
        self.D_b2.weight.data.zero_()
        self.use_cuda = torch.cuda.is_available()
        self.criterion = nn.BCEWithLogitsLoss()
        if self.use_cuda:
            self.disc_params = nn.ModuleList([self.dis_embeddings, self.D_W2, self.D_b2])
            for i, l in enumerate(self.disc_params):
                self.add_module(str(i), l)
            self.disc_params = self.disc_params.cuda()
            self.disc_params = torch.nn.DataParallel(self.disc_params,
                    device_ids=range(torch.cuda.device_count()))
        self.optimizer = optim.Adam(self.disc_params.parameters(), lr=learning_rate)

    def forward(self, inputs, labels):
        embedded_inputs = self.dis_embeddings(inputs)
        embedded_labels = self.D_W2(labels)
        embedded_bias = self.D_b2(labels)
        pos_score = torch.sum(embedded_inputs * embedded_labels, 1) + embedded_bias
        return pos_score

Have you tried using

model = skipgram_discriminator()
model = nn.DataParallel(model).cuda()

and calling

CUDA_VISIBLE_DEVICES=0,1,2,3 python your_script.py

?
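
For reference, the end-to-end pattern would look roughly like this. This is just a sketch: it assumes the class no longer wraps disc_params in DataParallel or builds the optimizer inside __init__, and the index tensors are dummies for illustration.

import torch
import torch.nn as nn
import torch.optim as optim

model = skipgram_discriminator()                  # plain module, no internal DataParallel
model = nn.DataParallel(model).cuda()             # replicate across all visible GPUs
optimizer = optim.Adam(model.parameters(), lr=1e-4)

inputs = torch.randint(0, 150000, (512,)).cuda()  # dummy word indices
labels = torch.randint(0, 150000, (512,)).cuda()
scores = model(inputs, labels)                    # the batch dimension is split across the GPUs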

I have tried that as well, but it didn't work for me with this class; it did, however, work with an nn.Sequential-based model on the same computer. Below is how I load my data. To be doubly sure, I also use a prep function which essentially does
model = nn.DataParallel(model).cuda()

but is written like this:
from torch.backends import cudnn

def prep(model):
    ''' Convert the model for GPU '''
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model,
                device_ids=range(torch.cuda.device_count())).cuda()
        cudnn.benchmark = True
    return model

and my model is initialized as follows:

''' Prepare for Cuda '''
if args.cuda:
    discriminator = prep(skipgram_discriminator(vocabulary_size=args.vsize,\
            embedding_size=args.esize, pretrained=pretrained_embeddings))
    generator = prep(skipgram_generator(vocabulary_size=args.vsize,\
            embedding_size=args.esize))

And the data for a forward pass is prepared like this:
mb_iwords_gpu = Variable((mb_iword.long().cuda() if args.cuda else mb_iword))
mb_cwords_gpu = Variable((mb_cword.long().cuda() if args.cuda else mb_cword))

But this still doesn't allocate any memory on the second GPU. I hope this clarifies the issue a bit more. Any help is greatly appreciated.
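
For reference, a minimal way to check whether DataParallel splits a batch across devices at all (Probe is just a throwaway test module, not part of my model):

import torch
import torch.nn as nn

class Probe(nn.Module):
    def forward(self, x):
        # Each replica reports the device its slice of the batch landed on.
        print("replica on", x.device, "chunk size", x.shape[0])
        return x

probe = nn.DataParallel(Probe()).cuda()
probe(torch.zeros(8, 3).cuda())   # with 2 GPUs this should print cuda:0 and cuda:1, 4 rows each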


I'll answer this, because I stumbled upon the same problem 3 years later:

You can fix this by writing a generic function that creates an object attribute for each ModuleList element.
Assume you have all your layers, ordered (first-called layer first), in self.all_layers of your network, wrapped in nn.ModuleList(self.all_layers).
Then define the function

def _make_object_variables(obj, layers):
    for i, layer in enumerate(layers):
        setattr(obj, f"layer_{i}", layer)

and call _make_object_variables(self, self.all_layers) at the end of your __init__ function. This does exactly what was described above: it creates an object attribute for each layer, named "layer_0", "layer_1", ..., "layer_n".

In your forward pass, instead of indexing into self.all_layers, you then fetch each layer with getattr:

        for i, _ in enumerate(self.all_layers):
            layer = getattr(self, f"layer_{i}")
            ...

This worked very well for me to fix DataParallel training in a UNET structure.
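
For completeness, a minimal self-contained sketch of the workaround (TinyNet is a made-up example module, not the UNET mentioned above):

import torch
import torch.nn as nn

def _make_object_variables(obj, layers):
    # Register every ModuleList element as a direct attribute ("layer_0", "layer_1", ...).
    for i, layer in enumerate(layers):
        setattr(obj, f"layer_{i}", layer)

class TinyNet(nn.Module):
    def __init__(self, width=16, depth=3):
        super().__init__()
        self.all_layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        _make_object_variables(self, self.all_layers)

    def forward(self, x):
        for i, _ in enumerate(self.all_layers):
            layer = getattr(self, f"layer_{i}")   # fetch via the registered attribute
            x = torch.relu(layer(x))
        return x

net = nn.DataParallel(TinyNet()).cuda()
out = net(torch.randn(8, 16).cuda())              # batch of 8 split across the GPUs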


Amazing. This was my problem. nn.ModuleList doesn't seem to work with nn.DataParallel, as some GPUs end up with the weights on the CPU while others have them on the GPU. It's a hacky way to do it, but it seems we have no choice. Thanks so much!
