Thank you for your reply! Indeed, with your sample code I get the same results as you do, so the behavior must be introduced by the multi-head attention layer I omitted from my earlier snippet. Since I do not use a mask in that attention layer, I assumed it could not affect DataParallel, but it turns out the opposite is true.
Below is the reproducible code, though it may be a bit complex:
import torch
from torch import nn
import math
class Encoder1(nn.Module):
    def __init__(self):
        super(Encoder1, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
            nn.Conv1d(64, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
        )

    def forward(self, x):
        def run(x):
            print("x shape:", x.shape)
            out = self.conv(x)
            return out

        out = run(x)
        print("out shape:", out.shape)
        return out
class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, channel: int, heads: int, channel_per_head: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(channel, channel, bias=bias)
        self.heads = heads
        self.channel_per_head = channel_per_head

    def forward(self, x: torch.Tensor):
        # Input sequence encoding has shape [batch_size, channel, seq_len]
        x = x.permute(2, 0, 1)
        # [seq_len, batch_size, channel]
        head_shape = x.shape[:-1]
        x = self.linear(x)
        x = x.view(*head_shape, self.heads, self.channel_per_head)
        # [seq_len, batch_size, heads, channel_per_head]
        return x
class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, sequence_length: int, sequence_channel: int,
                 gene_length: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
        assert sequence_channel % heads == 0, "channel number needs to be divisible by heads"
        self.heads = heads
        self.sequence_length = sequence_length
        self.gene_length = gene_length
        self.sequence_channel = sequence_channel
        self.channel_per_head = sequence_channel // heads
        self.gene_switch = nn.Linear(gene_length, sequence_length, bias=bias)
        self.query = PrepareForMultiHeadAttention(sequence_channel, heads, self.channel_per_head, bias=bias)
        self.key = PrepareForMultiHeadAttention(sequence_channel, heads, self.channel_per_head, bias=bias)
        self.value = PrepareForMultiHeadAttention(sequence_channel, heads, self.channel_per_head, bias=bias)
        self.softmax = nn.Softmax(dim=1)  # softmax over the key positions (dim j of 'ijbh')
        self.output = nn.Linear(sequence_channel, sequence_channel)
        self.dropout = nn.Dropout(dropout_prob)
        self.scale = 1 / math.sqrt(self.channel_per_head)
        self.norm1 = nn.LayerNorm(sequence_channel)
        self.norm2 = nn.LayerNorm(sequence_channel)
        self.ffn = nn.Sequential(
            nn.Linear(sequence_channel, 4 * sequence_channel),
            nn.ReLU(),
            nn.Linear(4 * sequence_channel, sequence_channel),
        )

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        # query has shape [1, gene_length]
        # key and value have shape [batch_size, channel, seq_len]
        batch_size, channel, seq_len = key.shape
        assert seq_len == self.sequence_length
        assert channel == self.sequence_channel
        # Map the gene vector to the sequence length, then broadcast it to every
        # batch element and channel so it can be projected like key and value.
        query = self.gene_switch(query)
        query = (query.reshape([1, 1, seq_len])
                      .repeat_interleave(batch_size, dim=0)
                      .repeat_interleave(channel, dim=1))
        query = self.query(query)
        key = self.key(key)
        value2 = self.value(value)
        scores = self.get_scores(query, key)
        scores *= self.scale
        attn = self.softmax(scores)
        attn = self.dropout(attn)
        # [seq_len, batch_size, channel]
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value2).reshape(seq_len, batch_size, -1)
        x = self.norm1(x + value.permute(2, 0, 1))  # residual over the raw value
        x2 = self.ffn(x)
        x = self.norm2(x + x2)
        x = x.permute(1, 2, 0)  # [batch_size, channel, seq_len]
        return x
class Encoder2(nn.Module):
    def __init__(self, heads: int, gene_length: int):
        super(Encoder2, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
            nn.Conv1d(64, 64, kernel_size=9, padding=4),
            nn.BatchNorm1d(64),
        )
        self.attn = MultiHeadAttention(heads=heads, sequence_length=500,
                                       sequence_channel=64, gene_length=gene_length)

    def forward(self, x, gene):
        def run(x):
            print("x shape:", x.shape)
            out = self.conv(x)
            attnout = self.attn(query=gene, key=out, value=out)
            return attnout

        out = run(x)
        print("out shape:", out.shape)
        return out
model = Encoder1().cuda()
x = torch.randn(16, 4, 500).cuda()

print("encoder1 single")
out = model(x)
print(out.shape)

print("encoder1 DP")
model = nn.DataParallel(model)
out = model(x)
print(out.shape)

model = Encoder2(heads=8, gene_length=5).cuda()
x = torch.randn(16, 4, 500).cuda()
gene = torch.randn(1, 5).cuda()

print("encoder2 single")
out = model(x, gene)
print(out.shape)

print("encoder2 DP")
model = nn.DataParallel(model)
out = model(x, gene)
print(out.shape)
Its output (on a machine with 8 visible GPUs, as the encoder1 DP run shows) is:
encoder1 single
x shape: torch.Size([16, 4, 500])
out shape: torch.Size([16, 64, 500])
torch.Size([16, 64, 500])
encoder1 DP
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
x shape: torch.Size([2, 4, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
out shape: torch.Size([2, 64, 500])
torch.Size([16, 64, 500])
encoder2 single
x shape: torch.Size([16, 4, 500])
out shape: torch.Size([16, 64, 500])
torch.Size([16, 64, 500])
encoder2 DP
x shape: torch.Size([2, 4, 500])
out shape: torch.Size([2, 64, 500])
torch.Size([2, 64, 500])
I am still looking for the reason behind these results. What puzzles me is the encoder2 DP run: the forward prints only once, and the final output keeps the per-GPU batch size of 2 instead of being gathered back to 16, as if only one replica ever ran. In the meantime, could you help me identify which operations might be causing DataParallel not to behave as expected? Thank you very much for your kind help!
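While digging, one thing caught my attention: as far as I understand, nn.DataParallel scatters every tensor argument by chunking it along dim 0, one chunk per device, and my gene tensor has size 1 on dim 0. Below is a minimal, CPU-only sketch of what I suspect happens during the scatter step. The variable names are mine, and I may be misreading the scatter logic, so please correct me if this is wrong:

import torch

num_gpus = 8  # matches the 8 replicas seen in the encoder1 DP run

x = torch.randn(16, 4, 500)  # batch input: 16 rows on dim 0
gene = torch.randn(1, 5)     # extra input: only 1 row on dim 0

# DataParallel chunks each tensor argument along dim 0 (one chunk per device).
x_chunks = x.chunk(num_gpus, dim=0)
gene_chunks = gene.chunk(num_gpus, dim=0)
print(len(x_chunks), len(gene_chunks))  # 8 1

# The scattered inputs are then zipped together per device; zip stops at the
# shortest sequence, so only one (x, gene) pair is formed.
pairs = list(zip(x_chunks, gene_chunks))
print(len(pairs))         # 1
print(pairs[0][0].shape)  # torch.Size([2, 4, 500])

If this reading is correct, gene would be scattered to only one device, the per-device zip of inputs would be truncated to a single pair, and only one replica would run, which would match both the single print and the ungathered [2, 64, 500] output above.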