I want to visualize the attention maps of a Vision Transformer. As sample code, I want to register a forward hook just after the self-attention computation (line 100 in the screenshot below); the relevant code from the timm repo is reproduced in the screenshot. Since a dropout layer behaves like nn.Identity in eval mode, the input and output of self.attn_drop should be identical, so a forward hook on the dropout layer self.attn_drop should serve the purpose.
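Roughly, this is what the hooked attention block looks like (a simplified sketch I wrote from the timm code; the class name SketchAttention and the exact layout are mine, and the real implementation differs across timm versions). The key point is that the attention matrix of shape (B, num_heads, N, N) passes through self.attn_drop right after the softmax:

import torch
import torch.nn as nn

class SketchAttention(nn.Module):
    # Simplified paraphrase of timm's attention block, not the actual source.
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.0)   # behaves like nn.Identity in eval mode
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)    # each (B, num_heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale     # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)                       # <- hook target, (B, num_heads, N, N)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

With that in mind, here is the minimum working example I used to test the hook.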
import timm
import torch

# Prepare Model
vit_model = timm.create_model('vit_large_patch16_384', pretrained=True).to("cpu")
vit_model.eval();

x = torch.randn(3, 384, 384)
x.shape

attentions = []

def get_attention(module, input, output):
    # store both the output and the input of the hooked module
    attentions.append(output.detach().clone())
    attentions.append(input[0].detach().clone())

# hook the dropout layer inside the first attention block
model = vit_model.eval().blocks[0].attn.attn_drop
print(model)
hook = model.register_forward_hook(get_attention)
print(model._forward_hooks)

with torch.no_grad():
    output = model.eval()(x.unsqueeze(0))

print(attentions[0].shape, attentions[1].shape)
# torch.Size([1, 3, 384, 384]) torch.Size([1, 3, 384, 384])

hook.remove()
attentions = []
print(model._forward_hooks)
This way, the hook gets registered properly. The input and output shapes of the dropout layer are indeed identical, but they are not the shape I expect: the attention matrix should have shape (batch_size, num_heads, num_tokens, num_tokens), i.e., (1, 16, 577, 577), yet the hook reports torch.Size([1, 3, 384, 384]).
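As a quick sanity check, the expected shape follows directly from the standard ViT-Large/16 configuration (16 heads and 16x16 patches are my assumption about this timm config):

# Expected attention shape for vit_large_patch16_384 (assuming 16 heads, 16x16 patches)
num_patches = (384 // 16) ** 2        # 24 * 24 = 576 patch tokens
num_tokens = num_patches + 1          # plus the [CLS] token -> 577
num_heads = 16
print((1, num_heads, num_tokens, num_tokens))   # (1, 16, 577, 577)

What am I missing?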
Also, if I register the hooks in a for loop over all sub-modules whose names contain attn_drop, the hooks do get registered on the dropout layers, but the attentions list stays empty, so indexing it raises "list index out of range" (a sanity check on the matched module names is included after the code below).
Are these two things related? Does the dropout layer have some specific behavior?
import timm
import torch

# Prepare Model
vit_model = timm.create_model('vit_large_patch16_384', pretrained=True).to("cpu")
vit_model.eval();

x = torch.randn(3, 384, 384)
x.shape

class VITAttention:
    def __init__(self, model, attention_layer_name='attn_drop'):
        self.model = model
        self.attentions = []
        self.hook = []
        # register a forward hook on every sub-module whose name contains 'attn_drop'
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hook.append(module.register_forward_hook(self.get_attention))

    def get_attention(self, module, input, output):
        self.attentions.append(output.detach().clone())

    def __enter__(self, *args): return self

    def __exit__(self, *args):
        # remove all hooks on exit
        for handle in self.hook:
            handle.remove()

with VITAttention(vit_model) as hook:
    with torch.no_grad():
        output = hook.model.eval()(x.unsqueeze(0))
    print("Shape of qkv attention " + str(hook.attentions[0].shape))
    print("Shape of key attention " + str(hook.attentions[1].shape))