Multi-head attention complexity

I’m using a Transformer encoder for time-series prediction. For the multi-head attention part, I assumed the computational complexity would be the same regardless of the number of heads, since the model dimension $d$ is split evenly into $h$ heads of size $d/h$. However, training takes longer the more heads I use:

1 head: 2:29; 4 heads: 2:49; 8 heads: 3:18; 16 heads: 4:08

Can anyone explain this?

Could you share the necessary (executable) part of your code, so that we can be sure of the setup and replicate your result?

Thanks for the kind reply. I implemented it as a standard Transformer encoder, but more heads usually take much more time to run. The backbone of the model is shown below:

class TokenEmbedding(nn.Module):
    """Linearly project each time step's input features to ``d_model``.

    Args:
        input_features: number of raw features per time step.
        d_model: width of the projected embedding.
    """

    def __init__(self, input_features, d_model):
        super().__init__()
        # Attribute name kept as ``tokenConv`` so state_dict keys are unchanged,
        # even though this is an nn.Linear rather than a convolution.
        self.tokenConv = nn.Linear(input_features, d_model)

    def forward(self, x):
        """Apply the projection to the last dimension of ``x``."""
        return self.tokenConv(x)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    A ``(1, seq_len, d_model)`` table is computed once in ``__init__`` and
    registered as a buffer. It therefore follows ``.to(device)`` and is saved
    in the state_dict, but is not a trainable parameter.

    Args:
        d_model: embedding width (may be odd).
        seq_len: maximum sequence length covered by the table.
    """

    def __init__(self, d_model, seq_len):
        super().__init__()
        self.d_model = d_model
        # NOTE(review): this dropout is constructed but never applied in
        # forward(); kept so the module's attributes don't change.
        self.dropout = nn.Dropout(p=0.1)
        pe = torch.zeros(seq_len, d_model)

        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column,
        # so trim div_term to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x) -> torch.Tensor:
        """Return the encoding for the first ``x.shape[1]`` positions.

        Args:
            x: tensor whose dim 1 is the sequence length; presumably
               (batch, seq_len, d_model) as produced by the token embedding.

        Returns:
            Tensor of shape ``(1, x.shape[1], d_model)`` that broadcasts
            against ``x`` for addition and does not require grad.

        Raises:
            ValueError: if ``x`` is longer than the precomputed table.
        """
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the precomputed maximum {max_len}"
            )
        return self.pe[:, :seq_len].requires_grad_(False)

class PositionWiseFeedForward(nn.Module):
    """Per-position bottleneck projection: hidden -> hidden // 4 -> hidden.

    Args:
        hidden_size: feature width of the input and output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # NOTE(review): there is no activation between the two Linear layers,
        # so this composes to a single (rank-limited) linear map. A standard
        # Transformer FFN places a nonlinearity (e.g. ReLU/GELU) here; adding
        # one would shift the Sequential indices and break existing
        # checkpoints, so it is only flagged, not changed.
        self.conv = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.Linear(hidden_size // 4, hidden_size),
        )

    def forward(self, tensor):
        """Apply the projection along the last dimension of ``tensor``."""
        return self.conv(tensor)

class RegressionModule(nn.Module):
    """Regression head: pool over time with a learned Linear, then project.

    Args:
        d_model: feature width of the encoder output.
        seq_len: number of time steps pooled by ``fc1``.
        output_size: number of regression outputs.
        dropout_rate: dropout probability applied before each Linear.
        pool: stored only; forward() does not read it.
    """

    def __init__(self, d_model, seq_len, output_size, dropout_rate, pool='cls'):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.output_size = output_size
        self.pool = pool
        # tanh and to_latent are registered but unused by forward(); they are
        # kept (in the same order) so module/parameter ordering is unchanged.
        self.tanh = nn.Tanh()
        self.to_latent = nn.Identity()
        self.fc1 = nn.Linear(seq_len, 1)  # learned weighted sum over time
        self.fc2 = nn.Linear(self.d_model, output_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Map a (N, seq_len, d_model) tensor to (N, output_size)."""
        time_last = x.transpose(1, 2)  # [N, d_model, seq_len]
        pooled = self.gelu(self.fc1(self.dropout(time_last)))  # [N, d_model, 1]
        flat = pooled.contiguous().view(-1, self.d_model)  # [N, d_model]
        return self.fc2(self.dropout(flat))

class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each followed by a
    residual add and BatchNorm1d over the feature dimension.

    Args:
        d_model: feature width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        seq_len: unused by this layer.
        dropout_rate: stored only; see the note on ``self.dropout``.

    NOTE(review): nn.MultiheadAttention computes one (seq_len x seq_len) score
    matrix per head and batch element, so softmax/score memory traffic grows
    with n_heads even though the projection FLOPs do not. It also returns
    attention weights by default. Presumably this is why wall-clock time
    rises with more heads; passing ``need_weights=False`` when ``att`` is not
    needed may help. Confirm by profiling.
    """

    def __init__(self, d_model, n_heads, seq_len, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        # Attention dropout deliberately left at the module default here.
        self.attn_layer = nn.MultiheadAttention(d_model, n_heads)
        self.attn_batch_norm = nn.BatchNorm1d(d_model)
        self.ff_layer = PositionWiseFeedForward(d_model)
        self.ff_batch_norm = nn.BatchNorm1d(d_model)
        # Registered but unused in forward(); kept for state_dict compatibility.
        self.batch_norm = nn.BatchNorm1d(d_model)
        # NOTE(review): p=0 makes the residual dropout a no-op, and
        # dropout_rate is not used here. Confirm this is intentional.
        self.dropout = nn.Dropout(0)

    def forward(self, x, static):
        """Run the block; ``static`` is accepted but not used.

        Expects ``x`` as (N, seq_len, d_model) and returns
        ``(out, att)`` with ``out`` in the same layout and ``att`` the
        weights returned by the attention module.
        """
        seq_first = x.transpose(0, 1)  # [seq_len, N, d_model]
        attended, att = self.attn_layer(seq_first, seq_first, seq_first)

        # BatchNorm1d normalises dim 1, so put features there: [N, d_model, seq_len]
        skip = seq_first.permute(1, 2, 0)
        hidden = self.attn_batch_norm(skip + self.dropout(attended.permute(1, 2, 0)))

        ff_out = self.ff_layer(hidden.transpose(1, 2))  # [N, seq_len, d_model]
        hidden = self.ff_batch_norm(hidden + self.dropout(ff_out.transpose(1, 2)))

        return hidden.transpose(1, 2), att  # back to [N, seq_len, d_model]

class Encoder(nn.Module):
    """Full model: token embedding + positional encoding, a stack of
    ``EncoderLayer`` blocks, and a ``RegressionModule`` head.

    Args:
        input_features: raw features per time step.
        seq_len: sequence length (also the positional-table size).
        n_heads: attention heads per layer; must divide ``d_model``.
        n_class: number of regression outputs.
        n_layers: number of encoder layers (0 is allowed).
        d_model: model width.
        dropout_rate: passed to the layers and the head.
    """

    def __init__(self, input_features, seq_len, n_heads, n_class, n_layers, d_model, dropout_rate):
        super().__init__()
        self.input_features = input_features
        self.seq_len = seq_len
        self.n_heads = n_heads
        self.n_class = n_class
        self.n_layers = n_layers
        self.d_model = d_model
        # sigmoid / gelu are registered but unused in forward(); kept so the
        # module layout is unchanged.
        self.sigmoid = nn.Sigmoid()
        self.gelu = nn.GELU()
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(0.1)
        self.token_embedding = TokenEmbedding(input_features, d_model)
        self.pos_embedding = PositionalEncoding(d_model, seq_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, seq_len, dropout_rate)
            for _ in range(n_layers)
        ])
        self.clf = RegressionModule(d_model, seq_len, n_class, dropout_rate)

    def forward(self, x, static):
        """Return ``(prediction, att)``.

        ``att`` is the attention output of the last layer, or ``None`` when
        the model has no layers. ``static`` is forwarded to every layer.
        """
        x = self.token_embedding(x)
        # Out-of-place add: avoids mutating the embedding output in place.
        x = x + self.pos_embedding(x)
        x = self.dropout(x)
        att = None  # defined even when n_layers == 0
        for layer in self.layers:
            x, att = layer(x, static)

        x = self.clf(x)

        return x, att

# NOTE(review): DEVICE is not defined in this snippet; presumably a
# torch.device or device string set up elsewhere.
# d_model=256 with n_heads=8 gives a per-head width of 256 / 8 = 32.
model = Encoder(
    input_features=5,
    seq_len=270,
    n_heads=8,
    n_class=1,
    n_layers=1,
    d_model=256,
    dropout_rate=0.5,
)
model = model.to(DEVICE)