My problem: how can I generate attention maps for 3D grayscale MRI data after training a Vision Transformer on a classification task?
My data shape is (120, 120, 120) and the model is a 3D ViT. For example:
import nibabel as nib
import torch

img = nib.load(...).get_fdata()                   # numpy array, shape (120, 120, 120)
img = torch.from_numpy(img).float()[None, None]   # add batch and channel dims
model = torch.load(...)
model.eval()
with torch.no_grad():
    output, attn = model(img)
Because the model has 6 transformer layers and 12 heads, the attn I get has shape
(6, 12, 65, 65) — presumably 64 patch tokens plus one CLS token.
I don't know how to project this back onto the original 3D grayscale volume. The examples I found online only deal with 2D images from ImageNet.
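Adapting the 2D recipe, I think the mapping should look something like the sketch below — average the heads of the last layer, take the CLS-to-patch attention, reshape it to the patch grid, and upsample to the volume. This assumes a patch size of 30, so 120/30 = 4 patches per axis and 4³ = 64 patches, but I'm not sure this is right:

```python
import torch
import torch.nn.functional as F

# attn has shape (layers, heads, tokens, tokens) = (6, 12, 65, 65);
# a random tensor stands in for the real attention here
attn = torch.rand(6, 12, 65, 65)

# Average the heads of the last layer, take the CLS row, drop the CLS column
cls_attn = attn[-1].mean(dim=0)[0, 1:]            # (64,)

# Assuming patch size 30: 120 / 30 = 4 patches per axis, 4**3 = 64 patches
grid = cls_attn.reshape(1, 1, 4, 4, 4)

# Upsample the patch grid to the original volume resolution
heatmap = F.interpolate(grid, size=(120, 120, 120),
                        mode="trilinear", align_corners=False)[0, 0]

# Normalize to [0, 1] so it can be overlaid on the MRI slices
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
print(heatmap.shape)  # torch.Size([120, 120, 120])
```

Is this the right way to do it, or should I be using something like attention rollout across all 6 layers instead of just the last one?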
Can anyone help me with this?