I am using the PyTorch `nn.MultiheadAttention` module
to build a vision transformer, following this blog post. The module is created as follows:
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
Then, during the forward pass, I call it to transform my input embeddings: x = x + self.attn(inp_x, inp_x, inp_x)[0]
What I want is to also return the queries, keys, and values, so I can visualize those weights and their gradients. Can I do that with nn.MultiheadAttention? Or is there another implementation that would help with that?
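For context, here is a minimal sketch of the block I'm working with (the class name and the surrounding pre-norm LayerNorm are assumptions based on the blog's design, not my exact code). Note that the second element of the tuple returned by `nn.MultiheadAttention` does give the attention weights when `need_weights=True`, but not the underlying Q/K/V tensors:

```python
import torch
import torch.nn as nn

class AttentionBlock(nn.Module):
    """Hypothetical pre-norm attention block, as in the blog's ViT."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

    def forward(self, x):
        # x has shape (seq_len, batch, embed_dim) since batch_first=False by default
        inp_x = self.norm(x)
        # need_weights=True returns (output, attn_weights); the weights are
        # averaged over heads by default, and Q/K/V are not exposed here
        attn_out, attn_weights = self.attn(inp_x, inp_x, inp_x, need_weights=True)
        x = x + attn_out
        return x, attn_weights
```

Running this on a random input of shape `(seq_len=5, batch=3, embed_dim=8)` yields an output of the same shape and averaged attention weights of shape `(batch, seq_len, seq_len)`.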