Return queries, keys and values from nn.MultiheadAttention

I am using PyTorch's nn.MultiheadAttention module to build a vision transformer, similar to the approach in this blog. The module is created as follows:

self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

Then, during the forward pass, I call it to transform my input embeddings:

x = x + self.attn(inp_x, inp_x, inp_x)[0]
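For reference, the forward call returns a tuple (attn_output, attn_weights), which is why [0] is taken above. A minimal sketch with made-up sizes (embed_dim=64, num_heads=4, a sequence of 10 tokens, batch of 2; none of these come from the question) that unpacks both elements. It passes average_attn_weights=False, which exists in PyTorch 1.11 and later, to keep one attention map per head:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the question does not state them.
embed_dim, num_heads, seq_len, batch = 64, 4, 10, 2

attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0)
inp_x = torch.randn(seq_len, batch, embed_dim)  # default layout: (seq, batch, embed)

# forward returns (attn_output, attn_weights); indexing [0] keeps only the output.
attn_output, attn_weights = attn(inp_x, inp_x, inp_x,
                                 need_weights=True,
                                 average_attn_weights=False)

print(attn_output.shape)   # torch.Size([10, 2, 64])
print(attn_weights.shape)  # torch.Size([2, 4, 10, 10]): one map per head
```

With the default average_attn_weights=True the weights are averaged over heads instead, giving shape (batch, seq_len, seq_len). These are the softmax attention maps, not Q, K or V themselves.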

I would also like to get the queries, keys and values so that I can visualize the weights and their gradients. Can I do that with nn.MultiheadAttention, or is there another implementation that would help?
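nn.MultiheadAttention does not return Q, K and V itself, but with the default kdim == vdim == embed_dim it stores the three input projections packed row-wise in in_proj_weight (shape (3 * embed_dim, embed_dim)) and in_proj_bias, so they can be recomputed. The sketch below uses a hypothetical helper, mha_with_qkv (not a PyTorch API), that reproduces the self-attention forward pass by hand and returns Q, K and V alongside the output. It assumes batch_first=False, bias=True, no masks, no bias_k/bias_v or add_zero_attn, and dropout=0.0; the sizes are illustrative:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def mha_with_qkv(mha: nn.MultiheadAttention, x: torch.Tensor):
    """Hypothetical helper: re-run self-attention by hand, returning Q, K, V too.

    Assumes batch_first=False, kdim == vdim == embed_dim, bias=True,
    no masks, no bias_k/bias_v, no add_zero_attn and no active dropout.
    """
    seq_len, batch, embed_dim = x.shape
    num_heads = mha.num_heads
    head_dim = embed_dim // num_heads

    # in_proj_weight stacks the Q, K and V projection matrices row-wise.
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    q = F.linear(x, w_q, b_q)  # each (seq_len, batch, embed_dim)
    k = F.linear(x, w_k, b_k)
    v = F.linear(x, w_v, b_v)

    def split_heads(t):
        # (seq_len, batch, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        return t.view(seq_len, batch, num_heads, head_dim).permute(1, 2, 0, 3)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    scores = qh @ kh.transpose(-2, -1) / math.sqrt(head_dim)
    weights = scores.softmax(dim=-1)  # (batch, num_heads, seq_len, seq_len)

    # Merge heads back to (seq_len, batch, embed_dim), then apply out_proj.
    merged = (weights @ vh).permute(2, 0, 1, 3).reshape(seq_len, batch, embed_dim)
    out = mha.out_proj(merged)
    return out, weights, q, k, v


torch.manual_seed(0)
embed_dim, num_heads, seq_len, batch = 64, 4, 10, 2  # illustrative sizes
attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0)
inp_x = torch.randn(seq_len, batch, embed_dim)

out, weights, q, k, v = mha_with_qkv(attn, inp_x)
for t in (q, k, v):
    t.retain_grad()  # q, k, v are non-leaf tensors; keep their .grad

# Sanity check against the built-in module (should agree up to float error).
ref_out, ref_weights = attn(inp_x, inp_x, inp_x, average_attn_weights=False)
print(torch.allclose(out, ref_out, atol=1e-5),
      torch.allclose(weights, ref_weights, atol=1e-5))

out.sum().backward()  # stand-in loss for illustration
print(q.grad.shape, k.grad.shape, v.grad.shape)

# Gradients of the projection weights, split into their Q, K, V parts.
grad_w_q, grad_w_k, grad_w_v = attn.in_proj_weight.grad.chunk(3, dim=0)
print(grad_w_q.shape)
```

Because q, k and v come from the hand-written pass rather than from inside the module, their gradients only reflect the real training signal if the model actually uses this output (for example, by calling the helper in place of self.attn in forward). The gradients of the projection weights themselves are available from the module after backward regardless, as the last lines show.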