Learning how to use DataParallel with a Transformer model

Hi all,

I am trying to learn about PyTorch; previously, I have only used Keras.

I worked through this example implementation of a transformer and got everything working, but I noticed that training was only using one of my two GPUs.
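
For reference, this is roughly how I confirmed that only one GPU was active (just a quick check I ran, not part of the tutorial):

import torch

print(torch.cuda.device_count())   # prints 2 on my machine
for i in range(torch.cuda.device_count()):
    # during training, only cuda:0 ever shows allocated memory
    print(f"cuda:{i}:", torch.cuda.memory_allocated(i))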

So, this code works:

model = Transformer(
    embedding_size, src_vocab_size, trg_vocab_size, src_pad_idx,
    num_heads, num_encoder_layers, num_decoder_layers, forward_expansion,
    dropout, max_len, device,
)
#model = nn.DataParallel( model, device_ids=[0, 1] )
model = model.to( device )

The first couple of epochs look like this:

[Epoch 0 / 10]
=> Saving checkpoint
Translated example sentence:
['pours', 'robe', 'jackson', 'bill', 'colonial', 'we', 'graduation', 'these', 'banners', 'magazine', 'lab', 'painter', 'wok', 'synchronized', 'building', 'camps', 'these', 'against', 'wok', 'cooler', 'other', 'washing', 'these', 'graduation', 'banners', 'rakes', 'lab', 'section', 'chipper', 'weight', 'cabin', 'patch', 'swimming', '1', 'vacuuming', 'pulled', 'shakes', 'hilly', 'substance', 'nursing', 'chipper', 'twirls', 'substance', 'books', 'other', 'washing', 'grabs', 'other', 'graduation', 'fan', 'these', 'washing', 'bongo', 'entrance', 'washing', 'magazine', 'lab', 'started', 'fountain', 'laughs', 'substance', 'magazine', 'other', 'rakes', 'started', 'magazine', 'magazine', 'lab', 'runway', 'mirrors', 'other', 'other', 'washing', 'groups', 'pushes', 'tiny', 'graduation', 'books', 'clothes', 'lab', 'brindle', 'fountain', 'lays', 'heads', 'other', 'seems', 'full', 'blankly', 'jackson', 'cabin', 'sweat', 'lab', 'practice', 'chested', 'meet', 'seems', 'posing', 'shoes', 'patrick', 'mud']
[Epoch 1 / 10]
=> Saving checkpoint
Translated example sentence:
['a', 'child', 'is', 'walking', 'in', 'a', 'red', 'shirt', 'is', 'playing', 'a', 'red', '.', '<eos>']
[Epoch 2 / 10]
=> Saving checkpoint
Translated example sentence:
['a', 'snowboarder', 'is', 'walking', 'down', 'a', 'bike', '.', '<eos>']
[Epoch 3 / 10]
=> Saving checkpoint
Translated example sentence:
['a', 'horse', 'is', 'walking', 'next', 'to', 'a', 'boat', 'next', 'to', 'a', 'boat', '.', '<eos>']

But if I uncomment that DataParallel line:

model = Transformer(
    embedding_size, src_vocab_size, trg_vocab_size, src_pad_idx,
    num_heads, num_encoder_layers, num_decoder_layers, forward_expansion,
    dropout, max_len, device,
)
model = nn.DataParallel( model, device_ids=[0, 1] )
model = model.to( device )

Everything falls apart after the first epoch:

[Epoch 0 / 10]
=> Saving checkpoint
Translated example sentence:
['international', 'drinking', 'thought', 'racing', 'statue', 'foreign', 'bikini', 'concert', 'straw', 'start', 'straw', 'gentlemen', 'hijab', 'buys', 'buys', 'boulder', 'surfboard', 'blazer', 'buys', 'hijab', 'kilts', 'hijab', 'stretched', 'clapping', 'aprons', 'gentlemen', 'sail', 'strobe', 'blazer', 'blazer', 'blazer', 'gentlemen', 'session', 'fierce', 'cave', 'olympic', 'plastic', 'mate', 'cavern', 'blazer', 'blazer', 'blazer', 'fortune', 'water', 'foreign', 'foods', 'kissing', 'water', 'gentlemen', 'unpaved', 'projected', 'siding', 'blinds', 'random', 'blues', 'celebration', 'tackled', 'selecting', 'blazer', 'completely', 'aprons', 'projected', 'aerial', 'water', 'blazer', 'blazer', 'blazer', 'projected', 'surfboards', 'blazer', 'holing', 'built', 'hijab', 'small', 'wind', 'projected', 'foreign', 'overlooking', 'goggles', 'operate', 'cavern', 'while', 'solo', 'kilts', 'sail', 'joy', 'theater', 'navigating', 'case', 'adults', 'coworker', 'gentlemen', 'aprons', 'gentlemen', 'surfboard', 'celebrating', 'subway', 'buys', 'stretched', 'blazer']
[Epoch 1 / 10]
=> Saving checkpoint
Translated example sentence:
['<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']
[Epoch 2 / 10]
=> Saving checkpoint
Translated example sentence:
['<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']
[Epoch 3 / 10]
=> Saving checkpoint
Translated example sentence:
['<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']

My guess is that some tensor is ending up on the wrong device, so it isn't being combined correctly during DataParallel's gather step?
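
To make that suspicion concrete, here is a toy module I put together (DeviceProbe is just something I made up for debugging, not from the tutorial). The tutorial's Transformer takes device in its constructor and uses it inside forward, and I wonder if that pattern misbehaves once DataParallel replicates the module across GPUs:

import torch
import torch.nn as nn

class DeviceProbe(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device  # hard-coded, like the tutorial's Transformer

    def forward(self, x):
        # Under DataParallel, forward runs once per replica, so x should
        # arrive on cuda:0 in one replica and cuda:1 in the other -- but
        # this tensor is always built on the device saved at __init__.
        mask = torch.ones(x.shape[0], device=self.device)
        print("input:", x.device, "| mask:", mask.device)
        return x

device = torch.device("cuda:0")
model = nn.DataParallel(DeviceProbe(device), device_ids=[0, 1]).to(device)
model(torch.randn(8, 4, device=device))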

I’ve had a look at the “My recurrent network doesn’t work with data parallelism” section of the documentation. I don’t think I have that problem: I’m not using the padding functions it refers to, and it looks like that issue would produce a dimension-mismatch error rather than this kind of silent failure.

But there are many moving parts here that I’m still learning about, so I’m really not sure.

If anyone has any suggestions for how to go about troubleshooting this, I would really appreciate it!

Thanks for any help!