Hello everyone,
I would like to extract self-attention maps from a model built around nn.TransformerEncoder.
For simplicity, I omit other elements such as positional encoding. Here is my code snippet.
import torch
import torch.nn as nn
num_heads = 4
num_layers = 3
d_model = 16
# multi-head transformer encoder layer
encoder_layers = nn.TransformerEncoderLayer(
    d_model, num_heads, dim_feedforward=64, dropout=0.1,
    norm_first=False, activation="relu", batch_first=True)
# multi-layer transformer encoder
transformer_encoder = nn.TransformerEncoder(
    encoder_layers, num_layers)
def extract_selfattention_maps(transformer_encoder, x, mask, src_key_padding_mask):
    attention_maps = []
    num_layers = transformer_encoder.num_layers
    num_heads = transformer_encoder.layers[0].self_attn.num_heads
    norm_first = transformer_encoder.layers[0].norm_first
    with torch.no_grad():
        for i in range(num_layers):
            # replicate the input that layer i feeds to its self-attention block
            h = x.clone()
            if norm_first:
                # with pre-norm, self-attention sees the normalized input
                h = transformer_encoder.layers[i].norm1(h)
            # re-run only the self-attention of layer i to grab its weights
            attn = transformer_encoder.layers[i].self_attn(
                h, h, h,
                attn_mask=mask,
                key_padding_mask=src_key_padding_mask,
                need_weights=True)[1]
            attention_maps.append(attn)
            # full forward of layer i to obtain the input of layer i + 1
            x = transformer_encoder.layers[i](
                x, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
    return attention_maps
batch_size = 8
seq_len = 25
x = torch.randn((batch_size, seq_len, d_model))
# all-False masks: nothing is masked out
src_mask = torch.zeros((seq_len, seq_len)).bool()
src_key_padding_mask = torch.zeros((batch_size, seq_len)).bool()
attention_maps = extract_selfattention_maps(transformer_encoder, x, src_mask, src_key_padding_mask)
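For reference, when I inspect what comes back, every layer gives a single map of shape [batch_size, seq_len, seq_len]:

for i, attn in enumerate(attention_maps):
    print(f"layer {i}: {tuple(attn.shape)}")  # (8, 25, 25), i.e. no per-head dimension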
First of all, does this look correct, please?
Second, I cannot find the source code for F.multi_head_attention_forward, which MultiheadAttention.forward calls internally, so I would appreciate some clarification.
In the code above, attn has shape [batch_size, seq_len, seq_len]; however, I would have expected num_heads attention maps per layer. How come there is only one?
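Is the per-head information only available if I request it explicitly? Inside the loop above, I was thinking of changing the self_attn call to something like this (just a sketch, assuming my PyTorch version already has the average_attn_weights argument):

# per-head maps: shape (batch_size, num_heads, seq_len, seq_len) instead of the head-averaged map
attn = transformer_encoder.layers[i].self_attn(
    h, h, h,
    attn_mask=mask,
    key_padding_mask=src_key_padding_mask,
    need_weights=True,
    average_attn_weights=False)[1]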
And is attn_output_weights supposed to be the unscaled attention logits Q.K^T, or the probabilities Softmax(Q.K^T / sqrt(d))?
If possible, I would like to access the result of Q.K^T for each head in each layer. Any hints, please?
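To make it concrete, here is the kind of helper I was thinking of writing myself if there is no built-in way. This is only a rough sketch, assuming the default nn.MultiheadAttention with a packed in_proj_weight (i.e. kdim == vdim == d_model); raw_attention_logits is just my own name for it:

import torch
import torch.nn.functional as F

def raw_attention_logits(self_attn, h):
    # h: (batch_size, seq_len, d_model) with batch_first=True
    d_model = self_attn.embed_dim
    num_heads = self_attn.num_heads
    head_dim = d_model // num_heads
    # split the packed input projection into the Q and K parts (V is not needed for the logits)
    w_q, w_k, _ = self_attn.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, _ = self_attn.in_proj_bias.chunk(3, dim=0)
    q = F.linear(h, w_q, b_q)  # (batch_size, seq_len, d_model)
    k = F.linear(h, w_k, b_k)
    # reshape to (batch_size, num_heads, seq_len, head_dim)
    q = q.view(q.size(0), q.size(1), num_heads, head_dim).transpose(1, 2)
    k = k.view(k.size(0), k.size(1), num_heads, head_dim).transpose(1, 2)
    # unscaled logits Q.K^T per head: (batch_size, num_heads, seq_len, seq_len)
    return torch.matmul(q, k.transpose(-2, -1))

The idea would be to call raw_attention_logits(transformer_encoder.layers[i].self_attn, h) inside the loop above. Does this match what F.multi_head_attention_forward computes internally, before the 1/sqrt(head_dim) scaling and the softmax?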
Best wishes to all.