Using the LIME library with PyTorch: the explain_instance function

Hello there,
I am trying to use lime.lime_image.LimeImageExplainer().explain_instance to interpret the predictions of a Swin Transformer that I trained from scratch on 100,000 galaxy images. I am not sure how classifier_fn works. The function takes two required parameters: the image, which should be a NumPy array, and a classifier_fn which (as I understand it) takes that same image, converts it to a tensor and returns the class probabilities. Below is a snippet of my code. I would very much appreciate help making it work, as this is high priority for me.
Thank you :)
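A point that often causes confusion: classifier_fn always receives a batch, not the single original image. In lime's image explainer, explain_instance builds num_samples perturbed copies of the image (some superpixels replaced), stacks them with np.array into batches of up to batch_size images (default 10) with shape (batch, H, W, C), and expects back an array of shape (batch, num_classes) holding one probability row per image. The minimal NumPy-only sketch below imitates that calling convention so shape bugs show up before LIME is involved; check_classifier_fn and dummy_classifier_fn are hypothetical helpers, not part of lime.

```python
import numpy as np

def check_classifier_fn(classifier_fn, image, num_classes, batch_size=10):
    """Hypothetical helper: call classifier_fn the way LIME does, on a stacked
    NumPy batch of shape (batch_size, H, W, C), and validate the result."""
    # Unperturbed copies are enough to test shapes and normalization of the output
    batch = np.array([image.copy() for _ in range(batch_size)])
    probs = np.asarray(classifier_fn(batch))
    # One row per image, one column per class
    assert probs.shape == (batch_size, num_classes), probs.shape
    # Each row should be a probability distribution
    assert np.allclose(probs.sum(axis=1), 1.0, atol=1e-5)
    return probs

# Hypothetical stand-in classifier with 8 classes (scores from per-channel means)
def dummy_classifier_fn(batch):
    rng = np.random.default_rng(0)
    weights = rng.normal(size=(batch.shape[-1], 8))
    logits = batch.mean(axis=(1, 2)) @ weights     # (N, C) @ (C, 8) -> (N, 8)
    logits -= logits.max(axis=1, keepdims=True)    # numerically stable softmax
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

image = np.random.default_rng(1).random((224, 224, 3)).astype(np.float32)
probs = check_classifier_fn(dummy_classifier_fn, image, num_classes=8)
print(probs.shape)  # (10, 8)
```

The same check can be run on the real classifier_fn with the real input_image_numpy before calling explain_instance.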

1- Create a Swin Transformer Model Instance

2- Load the weights to the Swin Instance

3- Read the image and apply transformations to it

4- Apply XAI.

import torch
import numpy as np
import torch.nn.functional as F
from lime import lime_image
from torchvision import transforms
from PIL import Image
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
from swin_transformer_pytorch.swin_transformer import SwinTransformer

# Step 1: Model instantiation

model = SwinTransformer(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), channels=3,
                        num_classes=8, head_dim=32, window_size=7,
                        downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True)

# Step 2: Load the model weights

# Raw string: in a normal string "\U" is parsed as a Unicode escape and raises a SyntaxError
dictionary_path = r"C:\Users\katy99\Desktop\MyCiApp\gz2_datasets\Swin_7_2023_cache.pth"

model.load_state_dict(torch.load(dictionary_path, map_location=torch.device("cpu")))
model.eval()  # inference mode: disables dropout and other training-only behavior
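Loading a state_dict only works if the architecture is rebuilt with exactly the same hyperparameters, and predictions are only reproducible after model.eval(). The sketch below shows the save/load round trip on a small hypothetical stand-in network (TinyNet is illustrative, not the Swin model) using an in-memory buffer instead of a .pth file.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the Swin model; any nn.Module behaves the same way
class TinyNet(nn.Module):
    def __init__(self, num_classes=8):
        super().__init__()
        self.body = nn.Sequential(
            nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU(),
            nn.Dropout(p=0.5), nn.Linear(32, num_classes),
        )

    def forward(self, x):
        return self.body(x)

torch.manual_seed(0)
trained = TinyNet()

# Save only the state_dict (in memory here; a file path works the same way)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture, then load the weights onto the CPU
model = TinyNet()
state_dict = torch.load(buffer, map_location=torch.device("cpu"))
result = model.load_state_dict(state_dict)  # strict=True: raises on missing/unexpected keys
model.eval()  # disable dropout so repeated predictions are identical

x = torch.rand(4, 3, 8, 8)
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
print(out1.shape)  # torch.Size([4, 8])
```

Without the eval() call, the Dropout layer would make out1 and out2 differ, which would also add noise to every probability LIME collects.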

# Step 3: Preprocess the test image

image_path = r"C:\Users\katy99\Desktop\MyCiApp\images_test\49.jpg"  # raw string avoids "\U" and "\4" escapes
input_size = 224
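test_transform = transforms.Compose([transforms.CenterCrop(input_size),
                                     # PIL image -> float (C, H, W) tensor in [0, 1]; Normalize only accepts tensors
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.094, 0.0815, 0.063], [0.1303, 0.11, 0.0913])])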

image = Image.open(image_path).convert("RGB")     # make sure the image has 3 channels
input_image = test_transform(image).unsqueeze(0)  # add a batch dimension: (1, C, H, W)
print(input_image.shape)

# Step 4: Use LIME to generate the explanation

# Crucial 1: convert the image tensor (C, H, W) into a NumPy array (H, W, C)
# for the first parameter, which is the input image

# Crucial 2: convert the NumPy batch (N, H, W, C) that LIME passes in back into
# a tensor (N, C, H, W) for the forward pass; this requires writing my own classifier_fn

input_image_numpy = input_image[0].permute(1,2,0).numpy()

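def classifier_fn(numpy_images):
    # LIME calls this with a batch of perturbed images of shape (N, H, W, C);
    # convert to a float tensor of shape (N, C, H, W) for the model
    torch_images = torch.from_numpy(numpy_images).permute(0, 3, 1, 2).float()
    with torch.no_grad():
        logits = model(torch_images)
    # Softmax over the logits, not over the argmax indices returned by torch.max
    probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()  # shape (N, num_classes)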
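The two "Crucial" conversions are plain axis permutations, and the second one has to undo the first. The sketch below checks that round trip with a random tensor standing in for the real preprocessed image; lime_batch only imitates the stacked batch LIME builds and is not a lime API.

```python
import numpy as np
import torch

# Stand-in for test_transform(image).unsqueeze(0): shape (1, C, H, W)
input_image = torch.rand(1, 3, 224, 224)

# Crucial 1: tensor (1, C, H, W) -> NumPy (H, W, C), the layout passed to explain_instance
input_image_numpy = input_image[0].permute(1, 2, 0).numpy()

# LIME stacks perturbed copies into (N, H, W, C); simulate a batch of 5 copies
lime_batch = np.stack([input_image_numpy] * 5)

# Crucial 2: NumPy (N, H, W, C) -> tensor (N, C, H, W), the layout the model expects
torch_batch = torch.from_numpy(lime_batch).permute(0, 3, 1, 2).float()

print(input_image_numpy.shape, lime_batch.shape, tuple(torch_batch.shape))
# (224, 224, 3) (5, 224, 224, 3) (5, 3, 224, 224)
```

Using permute rather than reshape or view matters here: reshape would produce the right shape but scramble the pixel values across channels.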

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(input_image_numpy, classifier_fn, top_labels=5, num_samples=1000)
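An alternative design, sketched here under the assumption that the model was trained on exactly this Normalize step, is to give LIME the unnormalized image in [0, 1] and move the normalization inside classifier_fn. LIME's segmentation and the plotted result then work on ordinary RGB values. make_classifier_fn is a hypothetical helper and the small nn.Sequential model is only a stand-in for the Swin Transformer.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

MEAN = [0.094, 0.0815, 0.063]
STD = [0.1303, 0.11, 0.0913]

def make_classifier_fn(model, mean=MEAN, std=STD, device="cpu"):
    """Hypothetical helper: wrap a PyTorch model so LIME can pass it batches of
    unnormalized (N, H, W, C) images with values in [0, 1]."""
    mean_t = torch.tensor(mean, device=device).view(1, 3, 1, 1)
    std_t = torch.tensor(std, device=device).view(1, 3, 1, 1)
    model = model.to(device).eval()

    def classifier_fn(images):
        batch = torch.from_numpy(np.asarray(images)).permute(0, 3, 1, 2).float().to(device)
        batch = (batch - mean_t) / std_t  # same normalization as test_transform
        with torch.no_grad():
            logits = model(batch)
        return F.softmax(logits, dim=1).cpu().numpy()  # (N, num_classes)

    return classifier_fn

# Hypothetical stand-in model with 8 output classes
stand_in = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 8))
classifier_fn = make_classifier_fn(stand_in)

images = np.random.default_rng(0).random((4, 224, 224, 3))  # float64 batch
probs = classifier_fn(images)
print(probs.shape)  # (4, 8)
```

With this wrapper one would pass something like np.asarray(transforms.CenterCrop(input_size)(image), dtype=np.float32) / 255 as the first argument of explain_instance, and temp from get_image_and_mask would already be in [0, 1], so it could be given to mark_boundaries without undoing the normalization.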

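# top_labels[0] is the most probable class; the other arguments are lime's default values
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True,
                                            num_features=5, hide_rest=False)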

# Step 5: Visualize the explanation

# Plot the original image with the most important regions outlined

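# temp is a copy of the normalized input; undo Normalize so values fall back into [0, 1]
mean = np.array([0.094, 0.0815, 0.063])
std = np.array([0.1303, 0.11, 0.0913])
plt.imshow(mark_boundaries(np.clip(temp * std + mean, 0, 1), mask))
plt.axis("off")
plt.show()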
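mark_boundaries only outlines the top few superpixels. A heatmap of every superpixel's weight can show more. A minimal sketch, assuming lime's ImageExplanation exposes the segmentation as explanation.segments (an (H, W) array of superpixel ids) and the weights as explanation.local_exp[label] (a list of (segment_id, weight) pairs); lime_heatmap is a hypothetical helper, demonstrated here on a toy segmentation with made-up weights.

```python
import numpy as np

def lime_heatmap(segments, segment_weights):
    """Map each superpixel id in `segments` (H, W int array) to its LIME weight.

    `segment_weights` is a list of (segment_id, weight) pairs, assumed to match
    the format of explanation.local_exp[label]."""
    heatmap = np.zeros(segments.shape, dtype=float)
    for segment_id, weight in segment_weights:
        heatmap[segments == segment_id] = weight
    return heatmap

# Toy 2x3 segmentation with three superpixels and made-up weights
segments = np.array([[0, 0, 1],
                     [2, 2, 1]])
weights = [(1, 0.5), (0, -0.2), (2, 0.1)]
heatmap = lime_heatmap(segments, weights)
print(heatmap)
```

For the real explanation this would be called as lime_heatmap(explanation.segments, explanation.local_exp[explanation.top_labels[0]]) and shown with a diverging colormap, e.g. plt.imshow(heatmap, cmap="RdBu") followed by plt.colorbar(), so that positive and negative evidence get opposite colors.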