MultiheadAttention for Vision

Can anyone help me understand how to use nn.MultiheadAttention for vision, and how to compute the query, key, and value when I have an intermediate value of shape (64, 16, 16)?

I am trying to implement the paper "Diffusion Models Beat GANs on Image Synthesis". It uses attention heads in its UNet.
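
For concreteness, here is a minimal sketch of the shape handling as I currently understand it. It assumes (64, 16, 16) means (channels, height, width) with a batch dimension in front, treats each of the 16*16 spatial positions as one token with a 64-dimensional embedding, and uses the same sequence as query, key, and value (self-attention). The class name `SpatialSelfAttention`, the GroupNorm layer, the residual connection, and the choice of 4 heads are my own assumptions for illustration, not necessarily what the paper's code does.

```python
import torch
import torch.nn as nn


class SpatialSelfAttention(nn.Module):
    """Hypothetical self-attention block over the spatial positions of a feature map."""

    def __init__(self, channels=64, num_heads=4):
        super().__init__()
        # channels must be divisible by both the group count and num_heads
        self.norm = nn.GroupNorm(8, channels)
        # batch_first=True makes the expected input shape (batch, seq_len, embed_dim)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one token per pixel
        seq = self.norm(x).flatten(2).transpose(1, 2)
        # Self-attention: query, key, and value are all the same sequence;
        # the module applies its own learned q/k/v projections internally
        out, _ = self.attn(seq, seq, seq, need_weights=False)
        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        # Residual connection so the block can be dropped into an existing network
        return x + out


x = torch.randn(2, 64, 16, 16)  # batch of 2 feature maps
y = SpatialSelfAttention(channels=64, num_heads=4)(x)
print(y.shape)
```

If this is right, I never compute the query, key, and value by hand: I pass the flattened feature map three times and the module projects it internally. Is that the correct way to use it here?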

Thank you.