Integrated Gradients with Captum and a hand-made Transformer model

Hi! I'm using Captum with a Transformer-based protein language model to identify input (embeddings)-output correlations. I took inspiration from the Captum website tutorials (the BERT model one), but I'm not able to run the last block of code, which is the part related to Captum.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.nn.utils.rnn as rnn_utils
import os
import time
from sklearn.metrics import auc, roc_curve, average_precision_score, precision_recall_curve
from termcolor import colored
import pdb

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def generate_data(file):
    # Amino acid dictionary
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    # open csv file 
    with open(file, 'r') as inf:
        lines = inf.read().splitlines()

    pep_codes = []
    labels = []
    peps = []
    
    for pep in lines:                           # for every row
        pep, label = pep.split(",")             # sequence and label split
        peps.append(pep)
        labels.append(int(label))
        current_pep = []
        for aa in pep:
            current_pep.append(aa_dict[aa])
        pep_codes.append(torch.tensor(current_pep))

        

    data = rnn_utils.pad_sequence(pep_codes, batch_first=True)  # Fill the sequence to the same length
  

    return data, torch.tensor(labels)
data, label = generate_data("./SSP_dataset.csv")

train_data, train_label = data[:1894], label[:1894]  # the first 1894 sequences are used for training
test_data, test_label = data[1894:], label[1894:]

train_dataset = Data.TensorDataset(train_data, train_label)
test_dataset = Data.TensorDataset(test_data, test_label)

batch_size = 64
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

class xAInet(nn.Module):
  def __init__(self):
    super().__init__()
    
    self.hidden_dim = 25
    self.batch_size = 32
    self.embedding_dim = 512

    self.embedding_layer = nn.Embedding(24, self.embedding_dim, padding_idx=0)
    self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)

    self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
    self.gru = nn.GRU(self.embedding_dim, self.hidden_dim, num_layers=2,
                      bidirectional=True, dropout=.2)
    
    self.block_seq = nn.Sequential(nn.Linear(15050, 2048),
                                   nn.BatchNorm1d(2048),
                                   nn.LeakyReLU(),
                                   nn.Linear(2048, 1024),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(),
                                   nn.Linear(1024, 256),
                                   nn.BatchNorm1d(256),
                                   nn.ReLU(),
                                   nn.Linear(256, 8),
                                   nn.Linear(8, 2))
    
    
  def forward(self, seq):
        embeddings = self.embedding_layer(seq)
        output = self.transformer_encoder(embeddings).permute(1, 0, 2)
        output, hn = self.gru(output)
        output = output.permute(1, 0, 2)
        hn = hn.permute(1, 0, 2)
       
        output = output.reshape(output.shape[0], -1)
        hn = hn.reshape(output.shape[0], -1)
        
        output = torch.cat([output, hn], 1)
        output = self.block_seq(output)

        return output, embeddings

  def train_model(self, seq):
    #with torch.no_grad():
        output,_ = self.forward(seq)

        return output

class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive
def collate(batch):
    seq1_ls = []
    seq2_ls = []
    label1_ls = []
    label2_ls = []
    label_ls = []


    batch_size = len(batch)
    for i in range(int(batch_size / 2)):
        seq1, label1 = batch[i][0], batch[i][1]
        seq2, label2 = batch[i + int(batch_size / 2)][0], batch[i + int(batch_size / 2)][1]

        label1_ls.append(label1.unsqueeze(0))
        label2_ls.append(label2.unsqueeze(0))
        label = (label1 ^ label2)
        seq1_ls.append(seq1.unsqueeze(0))
        seq2_ls.append(seq2.unsqueeze(0))
        label_ls.append(label.unsqueeze(0))

        

    seq1 = torch.cat(seq1_ls).to(device)
    seq2 = torch.cat(seq2_ls).to(device)

    

    label = torch.cat(label_ls).to(device)
    label1 = torch.cat(label1_ls).to(device)
    label2 = torch.cat(label2_ls).to(device)
    return seq1, seq2, label, label1, label2


train_iter_cont = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                              shuffle=True, collate_fn=collate)

device = torch.device("cuda", 0)


def evaluate(data_iter, net):
    pred_prob = []
    label_pred = []
    label_real = []
    for x, y in data_iter:
        x, y = x.to(device), y.to(device)
        outputs = net.train_model(x)
        outputs_cpu = outputs.cpu()
        y_cpu = y.cpu()
        pred_prob_positive = outputs_cpu[:, 1]
        pred_prob = pred_prob + pred_prob_positive.tolist()
        label_pred = label_pred + outputs.argmax(dim=1).tolist()
        label_real = label_real + y_cpu.tolist()
    performance, roc_data, prc_data = caculate_metric(pred_prob, label_pred, label_real)
    return performance, roc_data, prc_data


def caculate_metric(pred_prob, label_pred, label_real):
    test_num = len(label_real)
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    for index in range(test_num):
        if label_real[index] == 1:
            if label_real[index] == label_pred[index]:
                tp = tp + 1
            else:
                fn = fn + 1
        else:
            if label_real[index] == label_pred[index]:
                tn = tn + 1
            else:
                fp = fp + 1

    # Accuracy
    ACC = float(tp + tn) / test_num

    # Sensitivity
    if tp + fn == 0:
        Recall = Sensitivity = 0
    else:
        Recall = Sensitivity = float(tp) / (tp + fn)

    # Specificity
    if tn + fp == 0:
        Specificity = 0
    else:
        Specificity = float(tn) / (tn + fp)

    # MCC
    if (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) == 0:
        MCC = 0
    else:
        MCC = float(tp * tn - fp * fn) / (np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))

    # ROC and AUC
    FPR, TPR, thresholds = roc_curve(label_real, pred_prob, pos_label=1)

    AUC = auc(FPR, TPR)

    # PRC and AP
    precision, recall, thresholds = precision_recall_curve(label_real, pred_prob, pos_label=1)
    AP = average_precision_score(label_real, pred_prob, average='macro', pos_label=1, sample_weight=None)

    performance = [ACC, Sensitivity, Specificity, AUC, MCC]
    roc_data = [FPR, TPR, AUC]
    prc_data = [recall, precision, AP]
    return performance, roc_data, prc_data


def to_log(log):
    with open("./results/ExamPle_Log.log", "a+") as f:
        f.write(log + '\n')

net = xAInet().to(device)
lr = 0.0001
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = ContrastiveLoss()
criterion_model = nn.CrossEntropyLoss(reduction='sum')
best_acc = 0
EPOCH = 5
for epoch in range(EPOCH):
    loss_ls = []
    loss1_ls = []
    loss2_3_ls = []
    t0 = time.time()
    net.train()
    for seq1, seq2, label, label1, label2 in train_iter_cont:
        output1,emb = net(seq1)
        output2,emb = net(seq2)
         
        #pdb.set_trace()
        output3 = net.train_model(seq1)

        output4 = net.train_model(seq2)
        loss1 = criterion(output1, output2, label)
        loss2 = criterion_model(output3, label1)
        loss3 = criterion_model(output4, label2)
        loss = loss1 + loss2 + loss3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_ls.append(loss.item())
        loss1_ls.append(loss1.item())
        loss2_3_ls.append((loss2 + loss3).item())

    net.eval()
    with torch.no_grad():
        train_performance, train_roc_data, train_prc_data = evaluate(train_iter, net)
        test_performance, test_roc_data, test_prc_data = evaluate(test_iter, net)

    results = f"\nepoch: {epoch + 1}, loss: {np.mean(loss_ls):.5f}, loss1: {np.mean(loss1_ls):.5f}, loss2_3: {np.mean(loss2_3_ls):.5f}\n"
    results += f'train_acc: {train_performance[0]:.4f}, time: {time.time() - t0:.2f}'
    results += '\n' + '=' * 16 + ' Test Performance. Epoch[{}] '.format(epoch + 1) + '=' * 16 \
               + '\n[ACC,\tSE,\t\tSP,\t\tAUC,\tMCC]\n' + '{:.4f},\t{:.4f},\t{:.4f},\t{:.4f},\t{:.4f}'.format(
        test_performance[0], test_performance[1], test_performance[2], test_performance[3],
        test_performance[4]) + '\n' + '=' * 60
    print(results)
    # to_log(results)
    test_acc = test_performance[0]  # test_performance: [ACC, Sensitivity, Specificity, AUC, MCC]
    if test_acc > best_acc:
        best_acc = test_acc
        best_performance = test_performance
        filename = '{}, {}[{:.3f}].pt'.format('ExamPle' + ', epoch[{}]'.format(epoch + 1), 'ACC', best_acc)
        save_path_pt = os.path.join('./Model', filename)
        # torch.save(net.state_dict(), save_path_pt, _use_new_zipfile_serialization=False)
        best_results = '\n' + '=' * 16 + colored(' Best Performance. Epoch[{}] ', 'red').format(epoch + 1) + '=' * 16 \
                       + '\n[ACC,\tSE,\t\tSP,\t\tAUC,\tMCC]\n' + '{:.4f},\t{:.4f},\t{:.4f},\t{:.4f},\t{:.4f}'.format(
            best_performance[0], best_performance[1], best_performance[2], best_performance[3],
            best_performance[4]) + '\n' + '=' * 60
        print(best_results)
        best_ROC = test_roc_data
        best_PRC = test_prc_data

def model_output(inputs):
    inputs = inputs[0].unsqueeze(0)
    out, embeddings = model(inputs)
    # Apply softmax to convert prediction scores to probabilities
    probabilities = torch.softmax(out, dim=1)
    # Get the predicted classes by selecting the class with the highest probability
    predicted_classes = torch.argmax(probabilities, dim=1)
    return predicted_classes

import captum
from captum.attr import LayerIntegratedGradients

lig = LayerIntegratedGradients(model_output, model.embedding_layer)

def construct_input_and_baseline(text):

    max_length = 512
    #baseline_token_id = rnn_utils.pad_sequence()
    

    input_ids = []
    token_list = []
    
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    
    for char in text:
      if char in aa_dict:
        input_ids.append(aa_dict[char])
        token_list.append(char)

    baseline_token_id = 13
    baseline_input_ids = [baseline_token_id] * len(input_ids)

    input_ids_tensor = torch.tensor([input_ids], device='cpu')
    baseline_input_ids_tensor = torch.tensor([baseline_input_ids], device='cpu')

    return input_ids_tensor, baseline_input_ids_tensor, token_list

text = 'MSKSKMLVFKSKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFK'

input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')
print(f'all tokens: {all_tokens}')

desired_length = 299
padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in input_ids]
input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in baseline_input_ids]
baseline_input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1
                                    )

I got the following error:

RuntimeError                              Traceback (most recent call last)

<ipython-input-42-e20407bb5215> in <cell line: 1>()
----> 1 ig.attribute(inputs=input_ids, baselines=baseline_input_ids, target=0)

9 frames

/usr/local/lib/python3.10/dist-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
     40             @wraps(func)
     41             def wrapper(*args, **kwargs):
---> 42                 return func(*args, **kwargs)
     43 
     44             return wrapper

/usr/local/lib/python3.10/dist-packages/captum/attr/_core/integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    284             )
    285         else:
--> 286             attributions = self._attribute(
    287                 inputs=inputs,
    288                 baselines=baselines,

/usr/local/lib/python3.10/dist-packages/captum/attr/_core/integrated_gradients.py in _attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
    349 
    350         # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351         grads = self.gradient_func(
    352             forward_fn=self.forward_func,
    353             inputs=scaled_features_tpl,

/usr/local/lib/python3.10/dist-packages/captum/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
    110     with torch.autograd.set_grad_enabled(True):
    111         # runs forward pass
--> 112         outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
    113         assert outputs[0].numel() == 1, (
    114             "Target not provided when necessary, cannot"

/usr/local/lib/python3.10/dist-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    480     additional_forward_args = _format_additional_forward_args(additional_forward_args)
    481 
--> 482     output = forward_func(
    483         *(*inputs, *additional_forward_args)
    484         if additional_forward_args is not None

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-5-d2063671fd62> in forward(self, seq)
     28 
     29   def forward(self, seq):
---> 30         embeddings = self.embedding_layer(seq)
     31         output = self.transformer_encoder(embeddings).permute(1, 0, 2)
     32         output, hn = self.gru(output)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2208         # remove once script supports set_grad_enabled
   2209         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2211 
   2212 

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

So I tried using

input_ids.to(torch.long)

but nothing changed… can you help me, please?

Are you re-assigning the transformed tensor to input_ids (the cast is not an in-place operation), or are you passing it directly to the method?
If it still fails, could you make your code snippet executable by removing the data dependency, so that we can reproduce and debug the issue?
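
For example, here is a minimal sketch (not from your code) showing that the cast is not applied in-place:

x = torch.tensor([[1., 2., 3.]])
x.to(torch.long)        # returns a new tensor; x itself is still float32
x = x.to(torch.long)    # re-assignment is needed for the dtype to actually change
print(x.dtype)          # torch.int64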

I reassigned the tensor. I then resolved that error by applying .long() inside forward():

def forward(self, seq):
        seq = seq.long()   ####HERE
        embeddings = self.embedding_layer(seq)
        output = self.transformer_encoder(embeddings).permute(1, 0, 2)
        output, hn = self.gru(output)
        output = output.permute(1, 0, 2)
        hn = hn.permute(1, 0, 2)
       
        output = output.reshape(output.shape[0], -1)
        hn = hn.reshape(output.shape[0], -1)
        
        output = torch.cat([output, hn], 1)
        output = self.block_seq(output)
       
       
        return output

But a new error was raised:

RuntimeError                              Traceback (most recent call last)

<ipython-input-71-9c05c85df27b> in <cell line: 4>()
      2 
      3 
----> 4 attribution = ig.attribute(inputs=input_ids, baselines=baseline_input_ids, target=model_output(input_ids))           #(, baselines = baseline, target=0)

4 frames

/usr/local/lib/python3.10/dist-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
     40             @wraps(func)
     41             def wrapper(*args, **kwargs):
---> 42                 return func(*args, **kwargs)
     43 
     44             return wrapper

/usr/local/lib/python3.10/dist-packages/captum/attr/_core/integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    284             )
    285         else:
--> 286             attributions = self._attribute(
    287                 inputs=inputs,
    288                 baselines=baselines,

/usr/local/lib/python3.10/dist-packages/captum/attr/_core/integrated_gradients.py in _attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
    349 
    350         # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 351         grads = self.gradient_func(
    352             forward_fn=self.forward_func,
    353             inputs=scaled_features_tpl,

/usr/local/lib/python3.10/dist-packages/captum/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
    117         # torch.unbind(forward_out) is a list of scalar tensor tuples and
    118         # contains batch_size * #steps elements
--> 119         grads = torch.autograd.grad(torch.unbind(outputs), inputs)
    120     return grads
    121 

/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
    301         return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
    302     else:
--> 303         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    304             t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
    305             allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

I'm trying to make the implementation free of the data dependency, but it's a bit complicated and this is my first time doing it.

Could you try to narrow down the issue and post a minimal and executable code snippet reproducing the issue?

Yes, here is the data-free code:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.nn.utils.rnn as rnn_utils
from captum.attr import IntegratedGradients

class xAInet(nn.Module):
  def __init__(self):
    super().__init__()
    
    self.hidden_dim = 25
    self.batch_size = 32
    self.embedding_dim = 512

    self.embedding_layer = nn.Embedding(24, self.embedding_dim, padding_idx=0)
    self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)

    self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
    self.gru = nn.GRU(self.embedding_dim, self.hidden_dim, num_layers=2,
                      bidirectional=True, dropout=.2)
    
    self.block_seq = nn.Sequential(nn.Linear(15050, 2048),
                                   nn.BatchNorm1d(2048),
                                   nn.LeakyReLU(),
                                   nn.Linear(2048, 1024),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(),
                                   nn.Linear(1024, 256),
                                   nn.BatchNorm1d(256),
                                   nn.ReLU(),
                                   nn.Linear(256, 8),
                                   nn.Linear(8, 2),
                                   nn.Softmax(dim=1))
    
    
  def forward(self, seq):
        seq = seq.long()
        embeddings = self.embedding_layer(seq)
        output = self.transformer_encoder(embeddings).permute(1, 0, 2)
        output, hn = self.gru(output)
        output = output.permute(1, 0, 2)
        hn = hn.permute(1, 0, 2)
       
        output = output.reshape(output.shape[0], -1)
        hn = hn.reshape(output.shape[0], -1)
        
        output = torch.cat([output, hn], 1)
        output = self.block_seq(output)
       
       
        return output

  def train_model(self, seq):
    #with torch.no_grad():
        output = self.forward(seq)

        return output

def model_output(inputs):
    #inputs = inputs[0].unsqueeze(0)
    out = model(inputs)
    # Apply softmax to convert prediction scores to probabilities
    probabilities = torch.softmax(out, dim=1)
    # Get the predicted classes by selecting the class with the highest probability
    predicted_classes = torch.argmax(probabilities, dim=1)
    return predicted_classes

def construct_input_and_baseline(text):

    max_length = 512
    #baseline_token_id = rnn_utils.pad_sequence()
    

    input_ids = []
    token_list = []
    
    aa_dict = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10,
               'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19,
               'W': 20, 'Y': 21, 'V': 22, 'X': 23}
    
    for char in text:
      if char in aa_dict:
        input_ids.append(aa_dict[char])
        token_list.append(char)

    baseline_token_id = 23
    baseline_input_ids = [baseline_token_id] * len(input_ids)

    input_ids_tensor = torch.tensor([input_ids], device='cpu')
    baseline_input_ids_tensor = torch.tensor([baseline_input_ids], device='cpu')

    return input_ids_tensor, baseline_input_ids_tensor, token_list

text = 'MSKSKMLVFKSKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFKMSKSKMLVFK'

input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

desired_length = 299
padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in input_ids]
input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

padded_sequences = [seq[:desired_length] if len(seq) >= desired_length else torch.cat((seq, torch.zeros(desired_length - len(seq)))) for seq in baseline_input_ids]
baseline_input_ids = rnn_utils.pad_sequence(padded_sequences, batch_first=True)

model = xAInet()
model.eval()

ig = IntegratedGradients(model, model.embedding_layer)


attribution = ig.attribute(inputs=input_ids, baselines=baseline_input_ids, target=model_output(input_ids))

Thanks for the update.
The issue seems to be caused by the cast to a LongTensor, which detaches the input tensor before it is passed to the embedding layer:

emb = nn.Embedding(10, 10)
x = torch.randint(0, 10, (1,)).float().requires_grad_()

out = emb(x.long())
torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out))

You won't be able to calculate gradients for x, since x.long() is already detached from it.
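
For comparison, here is a minimal sketch (not from your code) showing that gradients with respect to the float output of the embedding layer work as expected:

emb = nn.Embedding(10, 10)
x = torch.randint(0, 10, (1,))
out = emb(x)                              # float tensor produced by the embedding layer
loss = out.pow(2).sum()
grads = torch.autograd.grad(loss, out)    # works: out is part of the autograd graph
print(grads[0].shape)                     # torch.Size([1, 10])

So attributions can be computed at the level of the embedding output, even though they cannot be computed for the integer indices.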

Thank you for your help. So in this case it isn't possible to use IntegratedGradients, because the embedding layer expects the tensor for argument #1 'indices' to have a Long or Int scalar type, is that right?

RuntimeError                              Traceback (most recent call last)

<ipython-input-16-1163c9a7c017> in <cell line: 6>()
      4 
      5 
----> 6 attribution, delta = ig.attribute(inputs=input_ids, baselines=baseline_input_ids, target=model_output(input_ids), return_convergence_delta=True, internal_batch_size=1)

5 frames

<ipython-input-5-5adf40f3a278> in model_output(inputs)
      3     #inputs = inputs[0].unsqueeze(0)
      4 
----> 5     out = model(inputs)
      6     # Apply softmax to convert prediction scores to probabilities
      7     probabilities = torch.softmax(out, dim=1)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-4-f9b43476cb2d> in forward(self, seq)
     38   def forward(self, seq):
     39         #seq = seq.long()
---> 40         embeddings = self.embedding_layer(seq)
     41         output = self.transformer_encoder(embeddings).permute(1, 0, 2)
     42         output, hn = self.gru(output)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2208         # remove once script supports set_grad_enabled
   2209         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2211 
   2212 

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

Anyway, I resolved the issue by using LayerIntegratedGradients:

from captum.attr import LayerIntegratedGradients

ig = LayerIntegratedGradients(model, model.embedding_layer)


attribution, delta = ig.attribute(inputs=input_ids, baselines=baseline_input_ids, target=model_output(input_ids), return_convergence_delta=True, internal_batch_size=1)  
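
To get one attribution score per residue, I think the attributions can then be summed over the embedding dimension as in the Captum BERT tutorial (a minimal sketch, assuming attribution has shape [1, seq_len, embedding_dim]; the variable name is mine):

attributions_per_token = attribution.sum(dim=-1).squeeze(0)                            # one score per token
attributions_per_token = attributions_per_token / torch.norm(attributions_per_token)  # optional normalization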

I'm not familiar enough with IntegratedGradients, but you are generally not able to calculate gradients for integer types, as my code snippet shows.