Outputs from nn.DataParallel are not concatenated

I’ve found similar problems in DataParallel does not concat outputs from multi-gpu and Dimension problem by multiple GPUs. However, those issues turned out to be caused by unexpected model behaviour (for example, shape-changing masks in multi-head attention models) rather than by the output concatenation itself.

Here is my case:

class Encoder(nn.Module):
    ... (model init and layers omitted)

    def forward(self, x, gene_expression):
        def run(x):
            print(type(x))
            print("x shape:", x.shape)
            lout1 = self.lconv1(x)
            out1 = self.conv1(lout1)
            lout2 = self.lconv2(out1 + lout1)
            out2 = self.conv2(lout2)
            ... (further layers and operations omitted)
            return final_out

        out = run(x)
        print("out shape:", out.shape)
        return out


net0 = nn.DataParallel(Encoder())
print("sequence shape:", sequence.shape)
encoding0 = net0(torch.Tensor(sequence.float()).transpose(1, 2).cuda())
print("encoding0 shape:", encoding0.shape)

I tried using a single GPU (without DataParallel) and multiple GPUs. Without DataParallel, the batch sizes are as expected: with sequence shape torch.Size([16, 2000000, 4]), x shape is [16, 4, 2000000] and both the out shape and encoding0 shape are [16, 128, 500]. However, when using 8 GPUs, the batch sizes of x, out, and encoding0 are all 2.

So it seems to me that DataParallel splits my data correctly but does not concatenate the outputs. My inputs and outputs are plain torch tensors, which should not affect the behaviour of DataParallel.

Is there any advice on how to address this issue? Thank you in advance for your help!

I cannot reproduce the issue using:

import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 128, 3, 1, 1)

    def forward(self, x):
        def run(x):
            print("x shape:",x.shape)
            out = self.conv1(x)
            return out
        out = run(x)
        print("out shape:",out.shape)
        return out

device = "cuda"
model = Encoder().to(device)
x = torch.randn(16, 4, 500, device=device)

print("single")
out = model(x)
print(out.shape)

print("DP")
model = nn.DataParallel(model)
out = model(x)
print(out.shape)

and see:

single
x shape: torch.Size([16, 4, 500])
out shape: torch.Size([16, 128, 500])
torch.Size([16, 128, 500])
DP
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
out shape: torch.Size([2, 128, 500])
torch.Size([16, 128, 500])

Thank you for your reply! Indeed, with your sample code I get the same results as you, so this behaviour must be introduced by the multi-head attention layer I omitted. Since I do not use a mask in that layer, I assumed it would have no effect, but it turns out otherwise.

Below is the reproducible code, though it may be a bit complex:

import torch
from torch import nn
import math

class Encoder1(nn.Module):
    def __init__(self):
        super(Encoder1, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
            nn.Conv1d(64, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
        )

    def forward(self, x):
        def run(x):
            print("x shape:",x.shape)
            out = self.conv(x)
            return out

        out = run(x)
        print("out shape:",out.shape)
        return out

class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, channel: int, heads: int, channel_per_head: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(channel, channel, bias=bias)
        self.heads = heads
        self.channel_per_head = channel_per_head

    def forward(self, x: torch.Tensor):
        # Input sequence encoding has shape [batch_size, channel, seq_len]
        x = x.permute(2, 0, 1)
        # [seq_len, batch_size, channel]
        head_shape = x.shape[:-1]
        x = self.linear(x)
        x = x.view(*head_shape, self.heads, self.channel_per_head)
        # [seq_len, batch_size, heads, channel_per_head]
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, sequence_length: int, sequence_channel: int, gene_length: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
        assert (sequence_channel % heads == 0), "channel number need to be divisible by heads"
        self.heads = heads
        self.sequence_length = sequence_length
        self.gene_length = gene_length
        self.sequence_channel = sequence_channel
        self.channel_per_head = sequence_channel // heads
        
        self.gene_switch = nn.Linear(gene_length, sequence_length, bias=bias)
        self.query = PrepareForMultiHeadAttention(sequence_channel, heads, self.channel_per_head, bias=bias)
        self.key = PrepareForMultiHeadAttention(sequence_channel, heads, self.channel_per_head, bias=bias)
        self.value = PrepareForMultiHeadAttention(sequence_channel, heads, self.channel_per_head, bias=bias)
        
        self.softmax = nn.Softmax(dim=1)
        self.output = nn.Linear(sequence_channel, sequence_channel)
        self.dropout = nn.Dropout(dropout_prob)
        self.scale = 1 / math.sqrt(self.channel_per_head)
        
        self.norm1 = nn.LayerNorm(sequence_channel)
        self.norm2 = nn.LayerNorm(sequence_channel)
        
        self.ffn = nn.Sequential(
            nn.Linear(sequence_channel, 4 * sequence_channel),
            nn.ReLU(),
            nn.Linear(4 * sequence_channel, sequence_channel),
        )
        
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
        
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        # query has shape [1, gene_length]
        # key and value have shape [batch_size, channel, seq_len]
        batch_size, channel, seq_len = key.shape
        assert seq_len == self.sequence_length
        assert channel == self.sequence_channel
        
        query = self.gene_switch(query)
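        # expand the [1, seq_len] query to [batch_size, channel, seq_len] so it lines up with key/value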
        query = query.reshape([1, 1, seq_len]).repeat_interleave(batch_size, dim=0).repeat_interleave(channel, dim=1)
        query = self.query(query)
        key = self.key(key)
        value2 = self.value(value)

        scores = self.get_scores(query, key)
        scores *= self.scale
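        # scores: [seq_len (query), seq_len (key), batch_size, heads]; the softmax below normalises over the key axis (dim=1)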
        attn = self.softmax(scores)
        attn = self.dropout(attn)
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value2).reshape(seq_len, batch_size, -1) # [seq_len, batch_size, channel]
        x = self.norm1(x + value.permute(2, 0, 1))
        x2 = self.ffn(x)
        x = self.norm2(x + x2)
        x = x.permute(1, 2, 0)  # [batch_size, channel, seq_len]
        return x

class Encoder2(nn.Module):
    def __init__(self, heads: int, gene_length: int):
        super(Encoder2, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
            nn.Conv1d(64, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
        )

        self.attn = MultiHeadAttention(heads=heads, sequence_length=500, sequence_channel=64, gene_length=gene_length)

    def forward(self, x, gene):
        def run(x):
            print("x shape:",x.shape)
            out = self.conv(x)
            attnout = self.attn(query=gene, key=out, value=out)
            return attnout

        out = run(x)
        print("out shape:",out.shape)
        return out


model = Encoder1().cuda()
x = torch.randn(16, 4, 500).cuda()
print("encoder1 single")
out = model(x)
print(out.shape)
print("encoder1 DP")
model = nn.DataParallel(model)
out = model(x)
print(out.shape)


model = Encoder2(heads=8, gene_length=5).cuda()
x = torch.randn(16, 4, 500).cuda()
gene = torch.randn(1, 5).cuda()
print("encoder2 single")
out = model(x, gene)
print(out.shape)
print("encoder2 DP")
model = nn.DataParallel(model)
out = model(x, gene)
print(out.shape)

Its output is:

encoder1 single
x shape: torch.Size([16, 4, 500])
out shape: torch.Size([16, 64, 500])
torch.Size([16, 64, 500])
encoder1 DP
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
torch.Size([16, 64, 500])
encoder2 single
x shape: torch.Size([16, 4, 500])
out shape: torch.Size([16, 64, 500])
torch.Size([16, 64, 500])
encoder2 DP
x shape: torch.Size([2, 4, 500])
out shape: torch.Size([2, 64, 500])
torch.Size([2, 64, 500])

I am continuing to look for the reason behind these results. In the meantime, could you please help me see what operations might be causing DP not to behave as expected? Thank you very much for your kind help!

After checking the shapes of all the tensors, I realized that the query tensor was being broadcast in torch.einsum. The query is an additional input with batch size 1, so DataParallel scatters it into a single chunk and only one replica actually runs; it needs to be repeated so that its batch size can be split across the number of GPUs. So this problem has the same solution as mentioned here. Apologies for any confusion!
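
For reference, here is a rough sketch of the kind of change I mean, reusing the Encoder2 from the reproduction above (illustrative only; the exact edits in my real model differ slightly):

model = nn.DataParallel(Encoder2(heads=8, gene_length=5).cuda())
x = torch.randn(16, 4, 500).cuda()

# Repeat the extra input along dim 0 so DataParallel can scatter it into as
# many chunks as x. A [1, 5] tensor yields only a single chunk, so only one
# replica receives an (x, gene) pair and only its batch of 2 gets gathered.
gene = torch.randn(1, 5).cuda().repeat(x.size(0), 1)   # [16, 5] -> [2, 5] per GPU

# MultiHeadAttention.forward then has to build the query per sample, e.g. by
# replacing the reshape/repeat_interleave line with something like
#     query = query.unsqueeze(1).repeat_interleave(channel, dim=1)
# so that a [batch_size, seq_len] query becomes [batch_size, channel, seq_len].

out = model(x, gene)
print(out.shape)   # torch.Size([16, 64, 500]) once both changes are in place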