DataParallel not working as expected

Hello all!

I have a custom model class, which is basically a seq2seq network (encoder-decoder).
I have 4 GPUs available and need to parallelize the training across them.

Here is the main part of my training file:

model = Model(in_size, num_rnn_units_per_layer, out_size, num_rnn_layers, dropout, embed_size=embed_size)
if is_cuda: model.cuda()

model = nn.DataParallel(model, dim=0).cuda()

num_gpu = torch.cuda.device_count()

hidden_enc = model.module.encoder.init_hidden(batch_size, num_gpu)
hidden_dec = model.module.decoder.init_hidden(batch_size, num_gpu)

for x, y in train_reader.iter():
    print(x.shape)  # input shape is "bsz x seq_len"
    output, hidden_enc, hidden_dec = model(x, hidden_enc, hidden_dec)
    print(output.shape)  # output shape is "bsz/num_gpu x seq_len"

Where init_hidden is a function which initializes the hidden layer of the encoder/decoder:

def init_hidden(self, bsz, num_gpu):
    bsz //= num_gpu  # integer division so the size stays an int
    return Variable(torch.randn(2 * self.n_layers, bsz, self.hidden_size)).cuda()
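
For what it's worth, my understanding is that nn.DataParallel scatters every positional tensor argument along the given dim, so a hidden state shaped 2*n_layers x bsz x hidden_size would get chunked along its first (layer) dimension rather than the batch dimension. Here is a minimal sketch (the DebugNet module is just a hypothetical stand-in, and a multi-GPU machine is assumed) that prints what each replica actually receives:

import torch
import torch.nn as nn

class DebugNet(nn.Module):
    def forward(self, x, hidden):
        # each replica prints the chunk of every argument it was handed
        print('replica x:', x.shape, '| replica hidden:', hidden.shape)
        return x

bsz, seq_len, n_layers, hidden_size = 8, 10, 2, 16
x = torch.randn(bsz, seq_len).cuda()
hidden = torch.randn(2 * n_layers, bsz, hidden_size).cuda()

net = nn.DataParallel(DebugNet(), dim=0).cuda()
out = net(x, hidden)
# x comes out split along the batch dimension as expected,
# while hidden is split along its first (layer) dimension instead
print('gathered output:', out.shape)  # back to bsz x seq_len after the gather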

As mentioned in the comments of the training code above, here is what happens:
After a forward call of the model on the correct input, I get only a partial output.
That is, if the input shape is bsz x seq_len, my output shape is bsz/num_gpu x seq_len.

I can confirm the division by num_gpu because I have tried allocating different numbers of GPUs to the training, and the pattern holds in every case. (It works fine with 1 GPU, of course.)

Any help on this is much appreciated! :smiley:

How did you get output? It’s not clear how it’s generated from this code.

@richard, sorry for the typo in the provided code; the line should be this:

output, hidden_enc, hidden_dec = model(x, hidden_enc, hidden_dec)

I’ll update the line in the provided code snippet.

Also, I was able to get past the issue by making the forward pass of the model take only one argument, i.e. "x", and nothing else (in my provided code, removing both "hidden_enc" and "hidden_dec"). With that change, DataParallel returns output of the correct dimensions.
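
In case it helps, here is a rough sketch of that workaround (the Seq2SeqSketch module below is a hypothetical placeholder, not my actual encoder-decoder): the forward pass takes only x, and the hidden state is created inside forward, so each replica sizes it from its own chunk of the batch:

import torch
import torch.nn as nn

class Seq2SeqSketch(nn.Module):
    # hypothetical stand-in for the real encoder-decoder model
    def __init__(self, in_size, hidden_size, out_size, n_layers=1):
        super(Seq2SeqSketch, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.encoder = nn.GRU(in_size, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, out_size)

    def forward(self, x):
        # the hidden state is built here, from this replica's chunk of the batch
        bsz = x.size(0)
        hidden = x.new_zeros(self.n_layers, bsz, self.hidden_size)
        enc_out, hidden = self.encoder(x, hidden)
        return self.decoder(enc_out)

model = nn.DataParallel(Seq2SeqSketch(in_size=32, hidden_size=64, out_size=10), dim=0).cuda()
x = torch.randn(8, 15, 32).cuda()  # bsz x seq_len x in_size
print(model(x).shape)              # bsz x seq_len x out_size -- full batch size again

Since each replica builds its hidden state from x.size(0), there is no need to divide bsz by num_gpu by hand anymore.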

Maybe this can be fixed (after being reproduced first :stuck_out_tongue:) in upcoming versions of PyTorch.