BERT classifier and Integrated Gradients using Captum

Hi,
I’m trying to add interpretability to my model using the Integrated Gradients method from the Captum library.
Here is my code:

import torch.nn as nn
from transformers import BertModel

class Bert(nn.Module):
  def __init__(self):
    super(Bert, self).__init__()
    # pre-trained BERT encoder with a dropout + linear classification head
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.drop = nn.Dropout(0.5)
    self.fc = nn.Linear(768, 5)  # 5 output classes

  def forward(self, ids, masks):
    # with return_dict=False, the second output is the pooled [CLS] representation
    _, cls = self.bert(input_ids=ids, attention_mask=masks, return_dict=False)
    y = self.drop(cls)
    y = self.fc(y)
    return y

model = Bert()

Now, using the tokenizer, I extract the ids and masks for some example texts:

from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
A = ['this is an example', 'this is not an example']
tokens = tokenizer(A, max_length=512, padding='max_length', truncation=True, return_tensors='pt')
dataset = TensorDataset(tokens.input_ids, tokens.attention_mask)
data_loader = DataLoader(dataset, batch_size=1)

I created the data loader to feed the samples one at a time.
Now it is time to set up Captum. This is how I am trying to run it (where i and j are read from the data loader):

import torch.nn.functional as F
from captum.attr import IntegratedGradients

def predict(x, y):
  out = model(ids=x.long(), masks=y.long())
  return F.softmax(out, dim=-1)

integrated_gradients = IntegratedGradients(predict)
for i, j in data_loader:
  attributions_ig = integrated_gradients.attribute((i, j), target=0)

Now, my questions are:

  1. Am I setting up Captum correctly? Is everything right?
  2. When I run this code, it raises the following error:
    RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
    How should I fix that?

Many thanks in advance.

Have you solved this problem?
I also want to apply Integrated Gradients to a custom BERT classification model.

Any updates? Thx!

For those who have the same problem, read this Captum tutorial.

It’s not a classification problem, but I think the underlying mechanism is the same.
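
As far as I understand it, the key change in that tutorial is using LayerIntegratedGradients on the embedding layer instead of IntegratedGradients on the raw inputs: the integer input_ids and the attention mask are not differentiable, which is exactly what triggers that "differentiated Tensors" RuntimeError. Below is a rough, untested sketch of how that might look for the classifier above; the [PAD]-token baseline and the variable names are my own assumptions, not taken from the tutorial.

import torch
import torch.nn.functional as F
from captum.attr import LayerIntegratedGradients

model.eval()  # disable dropout so attributions are deterministic

def predict(ids, masks):
  # forward function for Captum: returns class probabilities
  return F.softmax(model(ids=ids, masks=masks), dim=-1)

# attribute with respect to the embedding layer rather than the integer ids
lig = LayerIntegratedGradients(predict, model.bert.embeddings)

for ids, masks in data_loader:
  # baseline: a sequence of [PAD] tokens with the same shape as the input
  baseline_ids = torch.full_like(ids, tokenizer.pad_token_id)
  attributions, delta = lig.attribute(
      inputs=ids,
      baselines=baseline_ids,
      additional_forward_args=(masks,),
      target=0,
      return_convergence_delta=True,
  )
  # sum over the embedding dimension to get one score per token
  token_attributions = attributions.sum(dim=-1)

Summing over the embedding dimension gives one score per input token, which you can line up with tokenizer.convert_ids_to_tokens(ids[0]) to see which words drive the prediction.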

I also read this post and this one for some insights.

Good luck with your research. :grin: