I’ve been doing a lot of research (Googling, Stack Overflow, forums, etc.) on using the pack_padded_sequence method with multiple GPUs, but I can’t find a solution. I looked at this:
https://pytorch.org/docs/stable/notes/faq.html#pack-rnn-unpack-with-data-parallelism
but it doesn’t seem to help.
Here is the stacktrace:
```text
Traceback (most recent call last):
  File "train.py", line 166, in <module>
    main(args)
  File "train.py", line 78, in main
    outputs = decoder(features, captions, lengths)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 114, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 124, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 65, in parallel_apply
    raise output
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 41, in _worker
    output = module(*input, **kwargs)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jeff/projects/python/pytorch/base_img_caption_model/model.py", line 42, in forward
    packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/onnx/__init__.py", line 57, in wrapper
    return fn(*args, **kwargs)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/utils/rnn.py", line 124, in pack_padded_sequence
    data, batch_sizes = PackPadded.apply(input, lengths, batch_first)
  File "/home/jeff/.pyenv/versions/pytorch_tut/lib/python3.6/site-packages/torch/nn/_functions/packing.py", line 25, in forward
    "{} (batch_size={}).".format(len(lengths), batch_size))
ValueError: Expected `len(lengths)` to be equal to batch_size, but got 512 (batch_size=256).
```
Code in the training loop:
```python
encoder = EncoderCNN(args.embed_size)
decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
if torch.cuda.device_count() > 1:
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
encoder.to(device)
decoder.to(device)
```
Code in the model:
```python
def forward(self, features, captions, lengths):
    """Decode image feature vectors and generate captions."""
    embeddings = self.embed(captions)
    embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
    packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
    hiddens, _ = self.lstm(packed)
    outputs = self.linear(hiddens[0])
    return outputs
```
The culprit is this line:
`packed = pack_padded_sequence(embeddings, lengths, batch_first=True)`
`nn.DataParallel` splits the batch into equal chunks, one per GPU (256 samples on each of the two GPUs), but the `lengths` array is not split: each replica still receives all 512 lengths, which is what causes the error. Is there a way to split the `lengths` array into per-GPU batches, or is there an easier way to approach this?
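To make the mismatch concrete, here is a simplified pure-Python model of how, as far as I understand it, `nn.DataParallel` scatters the arguments of `forward()`: tensors are chunked along dimension 0, lists and tuples are scattered element by element, and anything else (such as an `int`) is copied to every device. It is loosely modeled on `torch.nn.parallel.scatter_gather.scatter` and is an illustration only, not the real implementation; `FakeTensor` and this `scatter` function are hypothetical stand-ins, and `lengths` is assumed to be a plain Python list.

```python
# Simplified, pure-Python model of how nn.DataParallel scatters forward()
# arguments. Illustration only: FakeTensor and scatter() are hypothetical
# stand-ins, NOT PyTorch's actual implementation.
import math


class FakeTensor:
    """Hypothetical stand-in for a tensor: a list of rows along dim 0."""

    def __init__(self, rows):
        self.rows = list(rows)

    def chunk(self, n):
        # Like torch.chunk on dim 0: pieces of ceil(len / n) rows, last may be smaller.
        size = max(1, math.ceil(len(self.rows) / n))
        return [FakeTensor(self.rows[i:i + size])
                for i in range(0, len(self.rows), size)]


def scatter(obj, num_devices):
    """Return a list with one per-device version of `obj`."""
    if isinstance(obj, FakeTensor):
        return obj.chunk(num_devices)  # tensors are split along dim 0
    if isinstance(obj, tuple) and obj:
        return list(zip(*(scatter(o, num_devices) for o in obj)))
    if isinstance(obj, list) and obj:
        return [list(x) for x in zip(*(scatter(o, num_devices) for o in obj))]
    return [obj for _ in range(num_devices)]  # ints etc. are copied to every device


embeddings = FakeTensor(range(4))  # a batch of 4 samples
lengths = [5, 4, 3, 2]             # assumed: a plain Python list

for emb, lens in scatter((embeddings, lengths), num_devices=2):
    print(len(emb.rows), len(lens))  # prints "2 4" twice: lengths is not split

for emb, lens in scatter((embeddings, FakeTensor(lengths)), num_devices=2):
    print(len(emb.rows), len(lens.rows))  # prints "2 2" twice: split together
```

Scaled up to 512 samples on two GPUs, the first case is exactly "got 512 (batch_size=256)".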
I also found this snippet in another post that throws the same error; perhaps someone can use it as a minimal example for debugging:
```python
import numpy as np
import torch
from torch.autograd import Variable


class RNNDataParallel(torch.nn.Module):
    def __init__(self):
        super(RNNDataParallel, self).__init__()

    def forward(self, inputs, lengths):
        packed = torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True)
        return packed


model = RNNDataParallel()
model = torch.nn.DataParallel(model)
model = model.cuda()
inputs = Variable(torch.from_numpy(np.array([
    [1, 2, 3],
    [4, 5, 0],
])))
lengths = [3, 2]
packed = model(inputs, lengths)
print(packed)
```
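For comparison, here is a CPU-only sketch (no GPUs, no `DataParallel`) that imitates the per-device split by hand with `torch.chunk`, assuming `lengths` is passed as a 1-D tensor instead of a Python list so that it would be chunked along dimension 0 together with the inputs. The batch values and `num_replicas` are made up for illustration, and I have not checked this against a real multi-GPU run.

```python
# CPU-only sketch: imitate DataParallel's per-device split by hand, assuming
# `lengths` is a 1-D tensor so it is chunked along dim 0 with the inputs.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Padded batch of 4 sequences (batch_first=True), sorted by decreasing length.
inputs = torch.tensor([
    [1, 2, 3],
    [4, 5, 0],
    [6, 7, 0],
    [8, 0, 0],
], dtype=torch.float).unsqueeze(-1)   # shape (batch=4, seq_len=3, features=1)
lengths = torch.tensor([3, 2, 2, 1])  # a tensor, not a Python list

num_replicas = 2  # hypothetical: pretend there are two GPUs
packed_chunks = []
for x, l in zip(inputs.chunk(num_replicas, dim=0),
                lengths.chunk(num_replicas, dim=0)):
    # Each "replica" sees 2 samples and exactly 2 lengths, so packing succeeds.
    # Recent PyTorch versions expect CPU lengths here, hence the .cpu() call
    # (under real DataParallel the scattered lengths would sit on a GPU).
    packed_chunks.append(pack_padded_sequence(x, l.cpu(), batch_first=True))

for p in packed_chunks:
    print(p.batch_sizes.tolist())  # [2, 2, 1] then [2, 1]
```

The only difference from the failing repro is the type of `lengths`: a list is copied whole to every replica, while a tensor is split along with the batch.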
Any help or suggestions would be appreciated!