Getting the vision transformer attention matrix

Hi everyone,

I am asking about how to get the attention matrix from a pre-trained vision transformer. I tried to do some work around, yet I need to discuss with someone if this correct or wrong.

def hook(module, input, output):
    """
    A hook function to capture the output of a specific layer during the forward pass.

    This function is attached to a model layer and is called every time the layer processes input.
    It captures the output of the layer and appends it to a global list `features`.

    Args:
        module (torch.nn.Module): The layer to which this hook is attached.
        input (tuple): The input to the `module`. This is a tuple of tensors representing the inputs to the module.
        output (torch.Tensor): The output from the `module`. This is the tensor produced by the module as output.

    """
    # This is where the output of the layer is captured and appended to the global list `features`.
    features.append(output)

# Attach the hook to the layer of interest
# Since we want to see the visualization of the last layer in the encoder (layer 11) we will attach it to the hook
#The last component generating features for the the self-attention in encoder layer 11 is ln_1 (LayerNormalization), so we extract the features from
handle = m.encoder.layers.encoder_layer_11.ln_1.register_forward_hook(hook)
m.eval()

# Pass the image through the model at the inference mode
with torch.no_grad():
    outputs = m(input_batch)
handle.remove()

then I do this:

m.eval()

# Pass the image through the model
with torch.no_grad():
    # We get the output
    output = m(input_batch)
    # print(output.shape)

    # So here we get the weights and biases for the quer, key, and value
    qkv_w = m.encoder.layers.encoder_layer_11.self_attention.in_proj_weight
    qkv_b = m.encoder.layers.encoder_layer_11.self_attention.in_proj_bias

    print(f"The shape of qkv weight matrix before reshaping is {qkv_w.shape}\n")
    print(f"The shape of qkv bias matrix before reshaping is {qkv_b.shape}\n")
    # print(qkv_w.shape)
    """we have shape of (2304 * 768), we need to understand what is the meaning of the dimensions we have?
    first of all, the 768 represnets the D-hidden dimension through the encoder of the vision transformer which is fiexd across all of the encoder network.
    2304 is a little bit tricky and you need to check the original paper to understand why the shape looks like that.

    We have 3 components (query, keys, and values) for each head, and at the encoder (Architecture dependent) we have 12 heads, then we explore this as first divide 2304 by 12 to get dimensions for each head = 2304/12 = 192, here remember that we have 3 matrices stacked so 192/3 = 64, 
    which is the dimension of the head mentioned in the paper as D_{h} = D/k, and K is the number of heads which is 12 for the vit_b_16()"""

    #shape here is (matrices, d_head *k, d_hidden)
    qkv_w = qkv_w.reshape(3, -1, 768)
    qkv_b = qkv_b.reshape(12, -1, 64)

    print(f"The shape of qkv weight matrix after reshaping is {qkv_w.shape}\n")
    print(f"The shape of qkv bias matrix after reshaping is {qkv_b.shape}\n")

    "Here we get the weights and biases for each component for all of the heads"
    
    #shape here for each weight component is (d_head *k, d_hidden)
    q_w_12_heads = qkv_w[0,:,:]
    k_w_12_heads = qkv_w[1,:,:]
    v_w_12_heads = qkv_w[2,:,:]

    

    q_b_12_heads = qkv_b[:,0,:]
    k_b_12_heads = qkv_b[:,1,:]
    v_b_12_heads = qkv_b[:,2,:]


    print(f"The shape of query weight matrix before reshaping is {q_w_12_heads.shape}, key weight is {k_w_12_heads.shape}, and values weight is {v_w_12_heads.shape}\n")
    print(f"The shape of query bias matrix before reshaping is {q_b_12_heads.shape}, key bias is {k_b_12_heads.shape}, and values bias is {v_b_12_heads.shape}\n")

    # Shape here is (no.head, d_head, d_hidden)
    q_w_12_heads = q_w_12_heads.reshape(12, -1, 768)
    k_w_12_heads = k_w_12_heads.reshape(12, -1, 768)
    v_w_12_heads = v_w_12_heads.reshape(12, -1, 768)

    


    
    print(f"The shape of query weight matrix after reshaping is {q_w_12_heads.shape}, key weight is {k_w_12_heads.shape}, and values weight is {v_w_12_heads.shape}\n")
    # Shape here for each weight component is(d_head, d_hidden)
    q_w_1_head = q_w_12_heads[0,:,:]
    k_w_1_head = k_w_12_heads[0,:,:]
    v_w_1_head = v_w_12_heads[0,:,:]

    q_b_1_head = q_b_12_heads[0,:]
    k_b_1_head = k_b_12_heads[0,:]
    v_b_1_head = v_b_12_heads[0,:]

    print(f"The shape of query weight matrix after reshaping for one head is {q_w_1_head.shape}, key weight is {k_w_1_head .shape}, and values weight is {v_w_1_head .shape}\n")
    print(f"The shape of query bias matrix after reshaping for one head is {q_b_1_head.shape}, key bias is {k_b_1_head .shape}, and values bias is {v_b_1_head .shape}\n")


    out_encoder_10 = features[0][0]
    out_encoder_10 = out_encoder_10.unsqueeze(0)
    # print(out_encoder_10.shape)


    # place holder to get the attention weights from the heads to use it for later calculations
    att_weights =[]
    satt = []

    # This loop is created to loop over the heads, in order to get all of the attention matrices (qk^{T}) per heads
    for i in range(12):
        q_w = q_w_12_heads[i,:,:]
        k_w = k_w_12_heads[i,:,:]
        v_w = v_w_12_heads[i,:,:]

        q_b = q_b_12_heads[i,:]
        k_b = k_b_12_heads[i,:]
        v_b = v_b_12_heads[i,:]

        

        q = torch.matmul(out_encoder_10, q_w.T) 
        k = torch.matmul(out_encoder_10, k_w.T) 
        v = torch.matmul(out_encoder_10, v_w.T) 

        qk = torch.matmul(q, k.transpose(2, 1))/8
        qk = torch.softmax(qk, dim=(2))
        # print(qk.shape)
        att_weights.append(qk)

is this correct?

That is a great question, but can you share maybe part of your model codes? I am not sure how the hook function works in your model, which might be important for our discussion

1 Like

Hello Mohamed. I am curious if you got the proper implementation for the pre-trained Vision Transformer (from Torchvision). I would be glad to see the final implementation if it is possible. Thank you.

1 Like

Hello Rauf,

Yeah It is working currently, yet due to being an ongoing research once there is a publication I would put the code and sent it back to you.

Greetings!

Yeah sure!
Yet it is still an ongoing research