I am looking for clarification on the best way to use DataParallel with attention layers. For example, nn.MultiheadAttention expects inputs of shape (L, N, E), where L is the sequence length, N is the batch size, and E is the embedding size. The fact that the batch size is NOT the first dimension leads to problems when using DataParallel. To work around this I am transposing the dimensions; see the example below:
import torch
import torch.nn as nn

class AttnParallel(nn.Module):
    def __init__(self, dim, num_heads):
        super(AttnParallel, self).__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=0, bias=False)

    def forward(self, h, mask):
        print("h has shape:", h.shape)
        print("mask has shape:", mask.shape)
        h = h.transpose(0, 1).contiguous()  # (N, L, E) -> (L, N, E)
        h = self.attn(h, h, h, key_padding_mask=mask)[0]
        h = h.transpose(0, 1).contiguous()  # (L, N, E) -> (N, L, E)
        return h
# create model
dim = 4
num_head = 2
device = torch.device("cuda")
mod = AttnParallel(dim, num_head)
mod = nn.DataParallel(mod.to(device))

# create data
bsz = 16
L = 5
h = torch.rand(bsz, L, dim)
mask = torch.zeros(bsz, L).bool()
mask[0, 1] = True
mask[2, 4] = True

# forward
h = mod(h, mask)
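(For reference: I believe newer PyTorch versions, 1.9 and later, also accept batch_first=True in nn.MultiheadAttention, which would avoid the transposes entirely. A minimal CPU sketch of what I mean, assuming that version is available:)

```python
import torch
import torch.nn as nn

# Sketch (assumes PyTorch >= 1.9): batch_first=True keeps inputs as
# (N, L, E), so the batch dimension is already first and DataParallel
# can chunk over it directly, with no transposes in forward().
attn = nn.MultiheadAttention(4, 2, dropout=0, bias=False, batch_first=True)
h = torch.rand(16, 5, 4)          # (N, L, E)
mask = torch.zeros(16, 5).bool()  # key_padding_mask is (N, L) either way
out, _ = attn(h, h, h, key_padding_mask=mask)
assert out.shape == (16, 5, 4)
```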
I have a few questions:

My understanding is that when using DataParallel, whatever tensors I feed to the forward() function will be chunked over the first dimension into 8 pieces and fed to 8 replicas of my network (assuming 8 GPUs). So in this example, both the h and mask tensors will be chunked into 8 pieces. Eventually, the outputs of the 8 replicas are concatenated over the first dimension. Am I understanding this correctly?
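(This is my mental model of the scatter/gather, simulated on CPU with torch.chunk and torch.cat rather than actual GPU replicas:)

```python
import torch

# Simulate DataParallel's scatter/gather on CPU, assuming 8 replicas:
# inputs are chunked over dim 0, each replica sees one chunk, and the
# per-replica outputs are concatenated back over dim 0.
h = torch.rand(16, 5, 4)
mask = torch.zeros(16, 5).bool()
h_chunks = torch.chunk(h, 8, dim=0)        # 8 pieces of shape (2, 5, 4)
mask_chunks = torch.chunk(mask, 8, dim=0)  # 8 pieces of shape (2, 5)
assert all(c.shape == (2, 5, 4) for c in h_chunks)
outputs = [c * 1.0 for c in h_chunks]      # stand-in for each replica's forward()
out = torch.cat(outputs, dim=0)            # gather back over dim 0
assert out.shape == h.shape
```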

Is transposing the input the recommended way of dealing with modules that expect input whose first dimension is not the batch dimension? Is it recommended to use contiguous() to improve performance, or is that unnecessary?
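(For context, what I observe about contiguity: transpose() returns a non-contiguous view sharing the same storage, and contiguous() copies it into a fresh layout:)

```python
import torch

x = torch.rand(16, 5, 4)
y = x.transpose(0, 1)        # (5, 16, 4) view sharing x's storage
assert not y.is_contiguous()
z = y.contiguous()           # materializes a contiguous copy
assert z.is_contiguous()
assert torch.equal(z, y)     # same values, new memory layout
```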

Should it be nn.DataParallel(mod.to(device)) or nn.DataParallel(mod).to(device)? Both seem to work, but the doc says: "The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module." So I don't understand why nn.DataParallel(mod).to(device) works.
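(My guess for why the second form works: DataParallel registers the wrapped model as a submodule, so .to() on the wrapper recurses into it before any forward pass. A CPU-only sketch, using a dtype change in place of a device move since the effect on the inner parameters is analogous:)

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
dp = nn.DataParallel(lin)  # wrapped model is registered as dp.module
assert dp.module is lin    # same object, reachable as a submodule
dp.to(torch.float64)       # .to() on the wrapper recurses into dp.module
assert lin.weight.dtype == torch.float64
```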
Thanks!