Hi,
I’m trying to use SHAP's DeepExplainer on a ResNet50 imported from torchvision and run on CIFAR-100. The basic setup fails because of an in-place modification that PyTorch does not allow (though it probably works in Keras/TF). Is there a way to fix this? This used to be only a warning, but now it is an error that stops my code.
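To narrow it down, here is a smaller PyTorch-only sketch (no SHAP) that I expect to hit the same error. The traceback mentions `BackwardHookFunctionBackward`, which I believe PyTorch creates when a module has a full backward hook (`register_full_backward_hook`). My assumption is that SHAP attaches such hooks. `run_hooked_relu` is just a hypothetical helper for this test:

```python
import torch
import torch.nn as nn


def run_hooked_relu(inplace: bool) -> str:
    """Hypothetical helper: run one nn.ReLU that has a full backward hook.

    The hook wraps the module's input in a custom autograd Function; with
    inplace=True, ReLU then calls torch.relu_ on that wrapped tensor.
    """
    relu = nn.ReLU(inplace=inplace)
    # A no-op hook: only its presence matters, not what it does.
    relu.register_full_backward_hook(lambda module, grad_in, grad_out: None)
    x = torch.randn(2, 3, requires_grad=True) * 1.0  # non-leaf tensor that requires grad
    try:
        relu(x).sum().backward()
    except RuntimeError as err:
        return f"error: {err}"
    return "ok"


print(run_hooked_relu(inplace=True))   # expected to raise the same RuntimeError
print(run_hooked_relu(inplace=False))  # expected to run without error
```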
Thanks in advance.
Minimal Reproducible Example:
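```python
import torch
from torchvision import models, datasets, transforms
import shap

# Pretrained ResNet50 from torchvision (ImageNet weights)
model = models.resnet50(pretrained=True)

# CIFAR-100 training split, converted to tensors of shape (3, 32, 32)
dataset = datasets.CIFAR100(root="./",
                            download=True,
                            train=True,
                            transform=transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=1)

# Take a single batch of 8 images
images, labels = next(iter(dataloader))

# The same batch is used both as background data and as the samples to explain
e = shap.DeepExplainer(model, images)
shap_values = e.shap_values(images)
```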
Traceback:
```
Traceback (most recent call last):
  File "xai.py", line 339, in <module>
    shap_values = e.shap_values(test_images)
  File "/usr/local/lib/python3.8/dist-packages/shap/explainers/_deep/__init__.py", line 125, in shap_values
    return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
  File "/usr/local/lib/python3.8/dist-packages/shap/explainers/_deep/deep_pytorch.py", line 191, in shap_values
    sample_phis = self.gradient(feature_ind, joint_x)
  File "/usr/local/lib/python3.8/dist-packages/shap/explainers/_deep/deep_pytorch.py", line 107, in gradient
    outputs = self.model(*X)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/work/project/models/resnet.py", line 168, in forward
    x = self.relu(x) # 32x32
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 1440, in relu
    result = torch.relu_(input)
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
```