I am looking for clarification on the best way to use DataParallel with attention layers. As an example, MultiheadAttention expects inputs of shape (L, N, E), where L is the sequence length, N is the batch size, and E is the embedding size. The fact that the batch size is NOT the first dimension leads to problems when using DataParallel. To work around this I am transposing the dimensions; see the example below:
```python
import torch
import torch.nn as nn

class AttnParallel(nn.Module):
    def __init__(self, dim, num_heads):
        super(AttnParallel, self).__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=0, bias=False)

    def forward(self, h, mask):
        print("h has shape:", h.shape)
        print("mask has shape:", mask.shape)
        h = h.transpose(0, 1).contiguous()                 # (N, L, E) -> (L, N, E)
        h, _ = self.attn(h, h, h, key_padding_mask=mask)   # attn returns (output, weights)
        h = h.transpose(0, 1).contiguous()                 # (L, N, E) -> (N, L, E)
        return h

# create model
dim = 4
num_head = 2
device = torch.device("cuda")
mod = AttnParallel(dim, num_head)
mod = nn.DataParallel(mod.to(device))

# create data
bsz = 16
L = 5
h = torch.rand(bsz, L, dim)
mask = torch.zeros(bsz, L).bool()
mask[0, 1] = True
mask[2, 4] = True

# forward
h = mod(h, mask)
```
I have a few questions:
My understanding is that when using DataParallel, whatever tensors I feed to the forward() function will be chunked over the first dimension into 8 pieces and fed to 8 replicas of my network (assuming 8 GPUs). So in this example, both the h and mask tensors will be chunked into 8 pieces. Eventually, the outputs of the 8 replicas are concatenated over the first dimension. Am I understanding this correctly?
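To make sure I am picturing the scatter/gather step correctly, here is a small CPU-only sketch mimicking what I believe happens, using torch.chunk to stand in for the scatter across devices (2 "devices" for illustration; this is my mental model, not DataParallel internals):

```python
import torch

bsz, L, dim = 16, 5, 4
h = torch.rand(bsz, L, dim)
mask = torch.zeros(bsz, L).bool()

num_devices = 2  # pretend we have 2 GPUs
# scatter: both inputs are split over dim 0 (the batch dimension)
h_chunks = torch.chunk(h, num_devices, dim=0)
mask_chunks = torch.chunk(mask, num_devices, dim=0)

print([c.shape for c in h_chunks])     # each replica sees (8, 5, 4)
print([c.shape for c in mask_chunks])  # each replica sees (8, 5)

# gather: the per-replica outputs are concatenated back over dim 0
out = torch.cat(h_chunks, dim=0)
print(out.shape)                       # (16, 5, 4)
```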
Is transposing the input the recommended way of dealing with modules that expect inputs whose first dimension is not the batch dimension? And is it recommended to call contiguous() after the transpose to improve performance, or is that unnecessary?
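For context on why I am asking about contiguous(): as far as I understand, transpose() only changes the strides and returns a non-contiguous view, while contiguous() makes a copy with the new memory layout. A quick check:

```python
import torch

x = torch.rand(16, 5, 4)   # (N, L, E)
y = x.transpose(0, 1)      # (L, N, E) view; no data is copied
print(y.is_contiguous())   # False: only the strides changed
z = y.contiguous()         # materializes the (L, N, E) layout in memory
print(z.is_contiguous())   # True
```

So my question is whether paying for that copy is worth it before the attention call.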
Should it be nn.DataParallel(mod.to(device)) or nn.DataParallel(mod).to(device)? Both seem to work, but the doc says: "The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module." So I don't understand how come nn.DataParallel(mod).to(device) works?
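To illustrate what I mean, here is a minimal CPU stand-in (using nn.Linear in place of my attention module, just for illustration). My guess is that .to() applied to the DataParallel wrapper recursively moves the wrapped module's parameters too, so they end up on the right device before the first forward call anyway:

```python
import torch
import torch.nn as nn

# .to() on the wrapper also moves the wrapped module's parameters,
# which would explain why nn.DataParallel(mod).to(device) works:
# the parameters are in place before forward() ever runs.
wrapper = nn.DataParallel(nn.Linear(4, 4)).to("cpu")
param_device = next(wrapper.module.parameters()).device
print(param_device)  # cpu
```

Is that the right explanation, or is one of the two orderings still preferred?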