How can I generate attention maps for 3D grayscale MRI data after training a vision transformer (ViT) for a classification problem?

Each input volume has shape (120, 120, 120), and the model is a 3D ViT. For example:

```
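import nibabel as nib
import torch

img = nib.load()  # path to the NIfTI file omitted
img = torch.from_numpy(img.get_fdata()).float()  # NIfTI image -> float tensor, shape (120, 120, 120)
model = torch.load...
model.eval()
with torch.no_grad():  # inference only, no gradients needed
    output, attn = model(img)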
```

Since the model has 6 transformer layers and 12 heads, the attn I get back has shape

```
(6,12,65,65)
```

I don’t know how to map this back onto the original 3D grayscale images. The examples I found online only deal with 2D images from ImageNet.
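
To make the question concrete, here is a minimal sketch of the kind of mapping I am after. It rests on assumptions that are not confirmed for this model: that the 65 tokens are 1 class token followed by 64 patch tokens, that those 64 patches form a 4×4×4 grid over the volume (so 30×30×30 voxels per patch), and that the patches are flattened in the same axis order as the input. The helper name `cls_attention_to_volume` is hypothetical, and the random `attn` tensor only stands in for the real model output.

```python
import torch
import torch.nn.functional as F


def cls_attention_to_volume(attn, grid=(4, 4, 4), out_shape=(120, 120, 120)):
    """Hypothetical helper: project class-token attention onto the input volume.

    attn: tensor of shape (layers, heads, tokens, tokens).
    Assumes token 0 is the class token and the rest are patch tokens
    laid out in row-major order over `grid`.
    """
    last_layer = attn[-1].float()              # (heads, 65, 65)
    head_mean = last_layer.mean(dim=0)         # average over heads -> (65, 65)
    cls_to_patches = head_mean[0, 1:]          # class token's attention to the 64 patches

    # Reshape to the assumed patch grid and add batch/channel dims for interpolate.
    grid_map = cls_to_patches.reshape(1, 1, *grid)

    # Upsample the coarse 4x4x4 map to the voxel resolution of the MRI volume.
    vol = F.interpolate(grid_map, size=out_shape, mode="trilinear", align_corners=False)
    vol = vol[0, 0]

    # Min-max normalise to [0, 1] so it can be overlaid on the image.
    vol = (vol - vol.min()) / (vol.max() - vol.min() + 1e-8)
    return vol


# Stand-in for the real attention tensor returned by the model.
attn = torch.softmax(torch.randn(6, 12, 65, 65), dim=-1)
heatmap = cls_attention_to_volume(attn)
print(heatmap.shape)  # torch.Size([120, 120, 120])
```

The resulting `heatmap` has the same shape as the MRI volume, so a single slice such as `heatmap[:, :, 60]` could be overlaid on the matching image slice. Using only the last layer and averaging the heads is one choice among several; I am not sure it is the right one for 3D data.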

Can anyone help me with this?