Extracting self-attention maps from nn.TransformerEncoder

Hello everyone,

I would like to extract self-attention maps from a model built around nn.TransformerEncoder.
For simplicity, I omit other components such as positional encoding. Here is my code snippet:

import torch
import torch.nn as nn

# Encoder hyper-parameters.
num_heads = 4
num_layers = 3
d_model = 16

# One post-norm encoder layer: multi-head self-attention followed by a
# ReLU feed-forward block of width 64, with dropout 0.1, batch-first I/O.
encoder_layers = nn.TransformerEncoderLayer(
    d_model,
    num_heads,
    dim_feedforward=64,
    dropout=0.1,
    activation="relu",
    norm_first=False,
    batch_first=True,
)
# Stack of num_layers such layers.
transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)


def extract_selfattention_maps(transformer_encoder, x, mask=None, src_key_padding_mask=None,
                               *, average_attn_weights=True):
    """Collect the self-attention weights of every layer of an nn.TransformerEncoder.

    The input is propagated layer by layer. Before each layer's forward pass,
    that layer's ``self_attn`` module is re-run on the layer input (after
    ``norm1`` for pre-norm layers) with ``need_weights=True`` to read out the
    attention weights it would use.

    The encoder is temporarily put in eval mode. In training mode,
    nn.MultiheadAttention applies dropout to the attention weights it returns,
    and each layer's forward applies dropout to the activations passed on to
    the next layer, so the maps would be random. ``torch.no_grad()`` does not
    disable dropout. The caller's train/eval mode is restored before returning.

    Args:
        transformer_encoder: the nn.TransformerEncoder to inspect.
        x: encoder input, shaped as the encoder layers expect
            (batch-first in this snippet).
        mask: attention mask forwarded as ``attn_mask`` / ``src_mask`` (or None).
        src_key_padding_mask: key padding mask (or None).
        average_attn_weights: True (PyTorch's default) returns the head-averaged
            softmax probabilities, shape (batch, seq_len, seq_len). False returns
            one map per head, shape (batch, num_heads, seq_len, seq_len).

    Returns:
        A list with one attention-weight tensor per layer.
    """
    attention_maps = []
    was_training = transformer_encoder.training
    transformer_encoder.eval()
    try:
        with torch.no_grad():
            for layer in transformer_encoder.layers:
                # The attention sub-block sees norm1(x) in pre-norm layers and x otherwise.
                h = layer.norm1(x) if layer.norm_first else x
                _, attn = layer.self_attn(
                    h, h, h,
                    attn_mask=mask,
                    key_padding_mask=src_key_padding_mask,
                    need_weights=True,
                    average_attn_weights=average_attn_weights,
                )
                attention_maps.append(attn)
                # Feed the full layer output to the next layer.
                x = layer(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
    finally:
        transformer_encoder.train(was_training)
    return attention_maps


batch_size = 8
seq_len = 25

# Random batch-first input of shape (batch_size, seq_len, d_model).
x = torch.randn((batch_size, seq_len, d_model))

# All-False boolean masks, i.e. no position is masked out.
src_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
src_key_padding_mask = torch.zeros((batch_size, seq_len), dtype=torch.bool)

attention_maps = extract_selfattention_maps(
    transformer_encoder, x, src_mask, src_key_padding_mask
)

First, does this look correct?

Second, I cannot find the source code of F.multi_head_attention_forward, which is called in MultiheadAttention.forward, so I need some clarification.

In the code above, attn has shape [batch_size, seq_len, seq_len]; however, there should be num_heads attention maps per layer. How come?

Also, is attn_output_weights supposed to be the unscaled attention logits $QK^\top$, or the probabilities $\mathrm{softmax}(QK^\top/\sqrt{d})$?

If possible, I would like to access the result of $QK^\top$ for each head in each layer. Any hints?

Best wishes to all.

Alternatively, the call to multi_head_attention_forward could be replaced by performing the operations manually to obtain the desired tensors. In the code below this is done by compute_selfattention, which is derived from the PyTorch Lightning tutorial on Transformers. However, since I do not have the source code of PyTorch's built-in implementation, I would appreciate confirmation that the code below is correct.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Encoder hyper-parameters; each head works on d_model // num_heads features.
num_heads = 4
num_layers = 3
d_model = 16
d_head = d_model // num_heads

# One post-norm encoder layer: multi-head self-attention followed by a
# ReLU feed-forward block of width 64, with dropout 0.1, batch-first I/O.
encoder_layers = nn.TransformerEncoderLayer(
    d_model,
    num_heads,
    dim_feedforward=64,
    dropout=0.1,
    activation="relu",
    norm_first=False,
    batch_first=True,
)
# Stack of num_layers such layers.
transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

def _apply_attn_mask(scores, mask):
    """Return ``scores`` with ``mask`` applied, using nn.MultiheadAttention's conventions.

    A bool mask marks disallowed positions with True; they are set to -inf.
    A floating-point mask is additive (e.g. 0 = allowed, -inf = disallowed),
    as produced by nn.Transformer.generate_square_subsequent_mask.
    ``mask`` must be broadcastable to ``scores``.
    """
    if mask.dtype == torch.bool:
        return scores.masked_fill(mask, float("-inf"))
    if mask.is_floating_point():
        return scores + mask.to(scores.dtype)
    # Other dtypes (e.g. integer 0/1 masks): keep the "positive means masked" rule.
    return scores.masked_fill(mask > 0, float("-inf"))


def compute_selfattention(transformer_encoder,x,mask,src_key_padding_mask,i_layer,d_model,num_heads):
    """Recompute the per-head self-attention of encoder layer ``i_layer`` by hand.

    Follows the packed-projection self-attention of nn.MultiheadAttention
    without attention dropout, i.e. the eval-mode weights.

    Args:
        transformer_encoder: encoder whose ``layers[i_layer].self_attn`` is used.
        x: input to the attention sub-block (already passed through ``norm1``
            for pre-norm layers). Assumed batch-first: (batch, seq_len, d_model).
        mask: None, or an attention mask of shape (seq_len, seq_len) or
            (batch * num_heads, seq_len, seq_len). Bool masks mark disallowed
            positions with True; float masks are added to the scaled scores.
        src_key_padding_mask: None, or a (batch, seq_len) mask over key
            positions (columns), with the same bool/float conventions.
        i_layer: index of the encoder layer.
        d_model: embedding dimension of the layer.
        num_heads: number of attention heads.

    Returns:
        ``(attn_logits, attn_probs)``, both of shape (batch, num_heads, seq_len, seq_len):
        the unscaled scores Q K^T, and softmax(Q K^T / sqrt(d_head) + masks).
    """
    self_attn = transformer_encoder.layers[i_layer].self_attn
    batch_size, seq_len = x.shape[0], x.shape[1]
    d_head = d_model // num_heads

    qkv = F.linear(x, self_attn.in_proj_weight, bias=self_attn.in_proj_bias)
    # in_proj_weight stacks [W_q; W_k; W_v] along its output dimension, so Q, K
    # and V are contiguous d_model-wide slices of the projection. Each of them is
    # then split into num_heads contiguous d_head-wide slices. Splitting per head
    # first (reshape to 3*d_head, then chunk) would mix Q columns of several heads.
    q, k, _ = qkv.chunk(3, dim=-1)
    q = q.reshape(batch_size, seq_len, num_heads, d_head).transpose(1, 2)  # [B, H, L, d_head]
    k = k.reshape(batch_size, seq_len, num_heads, d_head).transpose(1, 2)  # [B, H, S, d_head]

    attn_logits = torch.matmul(q, k.transpose(-2, -1))  # [B, H, L, S]
    scores = attn_logits / math.sqrt(d_head)

    if mask is not None:
        if mask.dim() == 3:
            # (B * H, L, S) masks are indexed batch-major, head-minor.
            mask = mask.reshape(batch_size, num_heads, *mask.shape[-2:])
        scores = _apply_attn_mask(scores, mask)
    if src_key_padding_mask is not None:
        # Broadcast the (B, S) padding mask over heads and query rows.
        scores = _apply_attn_mask(scores, src_key_padding_mask[:, None, None, :])

    attn_probs = F.softmax(scores, dim=-1)
    return attn_logits, attn_probs

def extract_selfattention_maps(transformer_encoder,x,mask,src_key_padding_mask):
    """Collect hand-computed attention logits and probabilities for every encoder layer.

    The input is propagated layer by layer. Before each layer's forward pass,
    ``compute_selfattention`` is called on that layer's attention input
    (``norm1(x)`` for pre-norm layers, ``x`` otherwise).

    The encoder is temporarily put in eval mode. In training mode each layer's
    forward applies dropout, so the inputs of later layers (and hence their
    maps) would be random. ``torch.no_grad()`` does not disable dropout. The
    caller's train/eval mode is restored before returning.

    Args:
        transformer_encoder: the nn.TransformerEncoder to inspect.
        x: encoder input, shaped as the encoder layers expect
            (batch-first in this snippet).
        mask: attention mask forwarded as ``src_mask`` (or None).
        src_key_padding_mask: key padding mask (or None).

    Returns:
        ``(attn_logits_maps, attn_probs_maps)``: two lists with one entry per
        layer, each entry being what ``compute_selfattention`` returns for that layer.
    """
    attn_logits_maps = []
    attn_probs_maps = []
    was_training = transformer_encoder.training
    transformer_encoder.eval()
    try:
        with torch.no_grad():
            for i, layer in enumerate(transformer_encoder.layers):
                d_model = layer.self_attn.embed_dim
                num_heads = layer.self_attn.num_heads
                # The attention sub-block sees norm1(x) in pre-norm layers and x otherwise.
                h = layer.norm1(x) if layer.norm_first else x
                attn_logits, attn_probs = compute_selfattention(
                    transformer_encoder, h, mask, src_key_padding_mask, i, d_model, num_heads)
                attn_logits_maps.append(attn_logits)
                attn_probs_maps.append(attn_probs)
                # Feed the full layer output to the next layer.
                x = layer(x, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
    finally:
        transformer_encoder.train(was_training)
    return attn_logits_maps, attn_probs_maps


batch_size = 8
seq_len = 25

# Random batch-first input of shape (batch_size, seq_len, d_model).
x = torch.randn((batch_size, seq_len, d_model))

# All-False boolean masks, i.e. no position is masked out.
# nn.MultiheadAttention also accepts an attention mask of shape
# (batch_size * num_heads, L, S), where L / S are the target / source sequence
# lengths. This allows a different mask for each element of the batch.
src_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
src_key_padding_mask = torch.zeros((batch_size, seq_len), dtype=torch.bool)

attn_logits_maps, attn_probs_maps = extract_selfattention_maps(
    transformer_encoder, x, src_mask, src_key_padding_mask
)

Thanks in advance for any corrections, or for pointing me to other options I may have missed for getting attention maps from nn.TransformerEncoder.