How to generate vision transformer attention maps for 3D grayscale MRI data

How can I generate attention maps for 3D grayscale MRI data after training a vision transformer (ViT) on a classification task?

Each volume has shape (120, 120, 120), and the model is a 3D ViT. My code looks roughly like this:

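import nibabel as nib
import torch

img = nib.load(...).get_fdata()                    # NIfTI image -> NumPy array, shape (120, 120, 120)
img = torch.from_numpy(img).float()[None, None]    # assumes the model expects (batch, channel, D, H, W)
model = torch.load(...)
model.eval()

with torch.no_grad():
    output, attn = model(img)                      # model returns logits and attention weights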

Because the model has 6 transformer layers with 12 heads each, the attn I get has the shape:


I don't know how to map these attention weights back onto the original 3D grayscale volume. The examples I found online only handle 2D ImageNet images.
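The 2D ImageNet recipes carry over to 3D once the final reshape becomes a 3D one. Below is a minimal NumPy sketch of attention rollout (Abnar and Zuidema, 2020), plus a plain "last layer, heads averaged" option. It rests on several assumptions that should be checked against your model: attn for one volume is stacked as (layers, heads, tokens, tokens); token 0 is a CLS token; the remaining tokens are patches flattened in row-major (depth, height, width) order, as a Conv3d patch embedding followed by flatten(2) would produce; and the patch size divides 120. The function names and the patch size of 10 used as an example (giving 12 x 12 x 12 = 1728 patch tokens plus CLS = 1729 tokens) are hypothetical. If your model pools patch tokens instead of using a CLS token, this does not apply directly.

```python
import numpy as np


def attention_rollout(attn):
    """Attention rollout over layers for a single sample.

    attn: array of shape (layers, heads, tokens, tokens), rows softmax-normalised.
    Returns a (tokens, tokens) matrix of attention accumulated across layers.
    """
    num_tokens = attn.shape[-1]
    eye = np.eye(num_tokens)
    rollout = eye.copy()
    for layer_attn in attn:
        a = layer_attn.mean(axis=0)            # average over heads
        a = a + eye                            # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)  # re-normalise each row
        rollout = a @ rollout                  # compose with the earlier layers
    return rollout


def attention_map_3d(attn, grid, patch, use_rollout=True):
    """Turn CLS-token attention into a voxel-level map (hypothetical helper).

    grid:  number of patches along (depth, height, width), e.g. (12, 12, 12)
    patch: patch size along each axis, e.g. (10, 10, 10) for a 120^3 volume
    Assumes token 0 is CLS and patch tokens are in row-major (d, h, w) order.
    """
    if use_rollout:
        joint = attention_rollout(attn)
    else:
        joint = attn[-1].mean(axis=0)          # last layer only, heads averaged
    cls_to_patches = joint[0, 1:]              # CLS row, without the CLS column
    if cls_to_patches.size != np.prod(grid):
        raise ValueError("token count does not match the patch grid")
    token_map = cls_to_patches.reshape(grid)
    # Nearest-neighbour upsampling: each token fills its own patch cube.
    volume = token_map
    for axis, size in enumerate(patch):
        volume = np.repeat(volume, size, axis=axis)
    # Scale to [0, 1] for display.
    volume = volume - volume.min()
    return volume / (volume.max() + 1e-8)
```

If the model returns a list of per-layer tensors of shape (batch, heads, tokens, tokens), they can be stacked for the first volume with `np.stack([a[0].detach().cpu().numpy() for a in attn])` and passed as `attention_map_3d(attn_np, grid=(12, 12, 12), patch=(10, 10, 10))`. Nearest-neighbour upsampling keeps each value tied exactly to its patch; for a smoother map, the (12, 12, 12) token map can instead be upsampled with `torch.nn.functional.interpolate(..., mode="trilinear")` on a (1, 1, D, H, W) tensor.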

For example:

Can anyone help me with this?
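Once a voxel-level attention map with the same (120, 120, 120) shape as the MRI volume is available (for example, CLS-token attention reshaped to the patch grid and upsampled), one common way to inspect it is to blend a slice of the map over the matching MRI slice. The sketch below does this in plain NumPy; the function name `overlay_slice`, the red colouring, and the `alpha` default are illustrative choices, and it assumes the attention map is already scaled to [0, 1].

```python
import numpy as np


def overlay_slice(mri, heat, axis=2, index=None, alpha=0.4):
    """Blend one slice of an attention map over the matching MRI slice as RGB.

    mri, heat: arrays of identical 3D shape; heat is assumed to lie in [0, 1].
    Returns a float RGB image in [0, 1] suitable for matplotlib's imshow.
    """
    if index is None:
        index = mri.shape[axis] // 2                   # default to the middle slice
    img = np.take(mri, index, axis=axis).astype(np.float64)
    h = np.take(heat, index, axis=axis).astype(np.float64)
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # grayscale to [0, 1]
    rgb = np.stack([img, img, img], axis=-1)
    red = np.zeros_like(rgb)
    red[..., 0] = 1.0
    w = alpha * h[..., None]                           # stronger attention, more red
    return (1.0 - w) * rgb + w * red
```

For example, `plt.imshow(overlay_slice(volume_np, heat, axis=2))` shows an axial-style mid-slice (which axis is axial depends on the scan orientation). To browse the whole volume in a NIfTI viewer instead, the map can be saved with the original affine, e.g. `nib.save(nib.Nifti1Image(heat.astype(np.float32), nifti_img.affine), "attn_map.nii.gz")`, where `nifti_img` is the object returned by `nib.load` and the file name is arbitrary.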