I’m working on an audio recognition task using a Transformer-based model in PyTorch. My input features are generated by a CNN-based embedding layer and have the shape [batch_size, d_model, n_token], where n_token is the sequence length and d_model is the feature dimension.
By default, nn.MultiheadAttention (when batch_first=False) expects input in the shape (seq, batch, feature). To make things more intuitive, I chose to set batch_first=True and then permute my data from [batch_size, d_model, n_token] to [batch_size, n_token, d_model] so that the time dimension comes before the feature dimension. Here’s a simplified code snippet:
# Original shape: [batch_size, d_model, n_token]
data = concat_cls_token(data) # [batch_size, d_model, n_token+1]
data = data.permute(0, 2, 1) # [batch_size, n_token+1, d_model]
multihead_att = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
data, _ = multihead_att(data, data, data)
# Result shape: [batch_size, n_token+1, d_model]
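For context, here is a self-contained version of that snippet with made-up sizes (the CLS-token concat below is only a simplified stand-in for my actual concat_cls_token):

import torch
import torch.nn as nn

# Made-up sizes for illustration only
batch_size, d_model, n_token, num_heads = 8, 256, 100, 4

data = torch.randn(batch_size, d_model, n_token)            # CNN output: [batch, d_model, n_token]

# Simplified stand-in for concat_cls_token: prepend a learnable token along the time axis
cls_token = nn.Parameter(torch.zeros(1, d_model, 1))
data = torch.cat([cls_token.expand(batch_size, -1, -1), data], dim=2)  # [batch, d_model, n_token+1]

data = data.permute(0, 2, 1)                                 # [batch, n_token+1, d_model]

multihead_att = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
data, _ = multihead_att(data, data, data)                    # [batch, n_token+1, d_model]

layer_norm = nn.LayerNorm(d_model)
out = layer_norm(data)                                       # normalizes over the last dim (d_model)
print(out.shape)                                             # torch.Size([8, 101, 256])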
After applying multi-head attention, I apply LayerNorm(d_model) directly to this [batch_size, n_token+1, d_model] tensor. My understanding is that nn.LayerNorm(d_model) normalizes over the trailing dimension(s) given by normalized_shape, here just the last dimension, so as long as the feature dimension (d_model) is last, it should work fine. But I have two main questions:
1. If I had stuck with the default multi-head attention format (seq, batch, feature), i.e. [n_token+1, batch_size, d_model], would LayerNorm(d_model) still correctly normalize along the feature dimension without permuting the tensor again? (I sketch the quick check I have in mind right after these questions.)
2. In practice, what’s the best approach for tasks like mine (audio sequence recognition)? Is it recommended to keep the data in [batch_size, seq_len, d_model] format before calling LayerNorm, or is it perfectly acceptable to use (seq, batch, feature) as long as the feature dimension is last?
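To make question 1 concrete, this is the quick check I have in mind (all sizes are made up), simply comparing LayerNorm(d_model) on the two layouts:

import torch
import torch.nn as nn

batch_size, seq_len, d_model = 8, 101, 256      # made-up sizes
layer_norm = nn.LayerNorm(d_model)              # normalized_shape = (d_model,)

x_batch_first = torch.randn(batch_size, seq_len, d_model)   # [batch, seq, feature]
x_seq_first = x_batch_first.permute(1, 0, 2)                # [seq, batch, feature]

y1 = layer_norm(x_batch_first)                  # normalizes over the trailing d_model dim
y2 = layer_norm(x_seq_first)                    # same: only the last dim matters to LayerNorm

# The two results should agree once permuted back to a common layout
print(torch.allclose(y1, y2.permute(1, 0, 2), atol=1e-6))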
Both my advisor and I are a bit uncertain. I’d really appreciate any guidance or clarification. Below are more details from my forward method and the corresponding AttentionBlock implementation for reference:
def forward(self, x: torch.Tensor):
    # Initial: x is [batch_size, d_model, num_tokens]
    x = self.expand(x)
    x = self.concat_cls_token(x)   # [batch_size, d_model, num_tokens+1]
    x = x.permute(0, 2, 1)         # [batch_size, num_tokens+1, d_model]
    x = self.positional_encoder(x)
    x = self.attention_block(x)    # [batch_size, num_tokens+1, d_model]
    x = x.permute(0, 2, 1)         # [batch_size, d_model, num_tokens+1]
    x = self.get_cls_token(x)      # [batch_size, d_model, 1]
    y = self.class_mlp(x)          # [batch_size, n_classes]
    return y
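In case it helps, the helper calls above roughly do the following (simplified stand-ins with made-up sizes, not my exact implementations; shapes follow the comments above):

import torch
import torch.nn as nn

d_model, n_classes = 256, 10  # made-up sizes

# Rough stand-in for concat_cls_token: prepend a learnable token along the time axis
def concat_cls_token(x: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
    # x: [batch, d_model, num_tokens], cls_token: [1, d_model, 1]
    return torch.cat([cls_token.expand(x.size(0), -1, -1), x], dim=2)  # [batch, d_model, num_tokens+1]

# Rough stand-in for get_cls_token: keep only the CLS position
def get_cls_token(x: torch.Tensor) -> torch.Tensor:
    # x: [batch, d_model, num_tokens+1]
    return x[:, :, :1]  # [batch, d_model, 1]

# Rough stand-in for class_mlp: flatten the CLS embedding and project to class logits
class_mlp = nn.Sequential(
    nn.Flatten(start_dim=1),        # [batch, d_model]
    nn.Linear(d_model, n_classes),  # [batch, n_classes]
)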
And here is the implementation of AttentionBlock:
from collections import OrderedDict

import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    @staticmethod
    def make_ffn(hidden_dim: int) -> torch.nn.Module:
        return nn.Sequential(
            OrderedDict([
                ("ffn_linear1", nn.Linear(in_features=hidden_dim, out_features=hidden_dim)),
                ("ffn_relu", nn.ReLU()),
                ("ffn_linear2", nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
            ])
        )

    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = self.make_ffn(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor):
        # x: [batch_size, num_tokens+1, d_model] (batch_first layout)
        attn_output, _ = self.attention(x, x, x)
        x = self.layer_norm1(x + attn_output)    # post-norm residual connection
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x
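And a quick shape check of how I call this block (sizes made up):

block = AttentionBlock(embed_dim=256, n_head=4)   # made-up dims
x = torch.randn(8, 101, 256)                      # [batch_size, num_tokens+1, d_model], batch_first layout
out = block(x)
print(out.shape)                                  # torch.Size([8, 101, 256])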
Any advice or best practices would be greatly appreciated. Thank you so much!