Using captum with nn.Embedding getting RuntimeError

I am using the Captum library and getting the following error. Here is the complete code to reproduce it. I would appreciate it if someone could help me here.

Thanks in advance. -OneQ-

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Hyper-parameters read by `predictor.__init__`:
#   vocab_size / embedding_dim -> size of the embedding table
#   seq_len                    -> number of indices per example
#   num_classes                -> width of the model output
#   hidden_dim                 -> stored on the model, not used by any layer here
vocab_size, embedding_dim = 1024, 32
seq_len, num_classes, hidden_dim = 128, 5, 256

class predictor(nn.Module):
    """Embedding lookup followed by one linear layer and a ReLU.

    Layer sizes are taken from the module-level ``vocab_size``,
    ``embedding_dim``, ``seq_len``, ``num_classes`` and ``hidden_dim``
    values, which are copied onto the instance when it is constructed.
    """

    def __init__(self):
        super().__init__()
        # Snapshot the module-level hyper-parameters on the instance.
        self.seq_len = seq_len
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim  # stored only; no layer below uses it
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Width of one example once all seq_len embeddings are laid side by side.
        flat_width = self.seq_len * self.embedding_dim
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.linear = nn.Linear(flat_width, self.num_classes)

    def forward(self, x):
        """Return ``relu(linear(flatten(embedding(x))))``.

        ``x`` is cast to ``long`` before the lookup, so index tensors of any
        numeric dtype are accepted. The embedded values are reshaped to rows
        of ``seq_len * embedding_dim`` elements before the linear layer.
        """
        token_ids = x.long()
        embedded = self.embedding(token_ids)
        flat = embedded.reshape(-1, self.seq_len * self.embedding_dim)
        return F.relu(self.linear(flat))

class wrapper_predictor(nn.Module):
    """Wrap ``model`` so that its output is passed through a softmax over dim 1."""

    def __init__(self, model):
        super().__init__()
        # Registered as a sub-module, so it follows `.to(device)` calls.
        self.model = model

    def forward(self, x):
        """Return ``softmax(self.model(x), dim=1)``."""
        return F.softmax(self.model(x), dim=1)
    
# Reproducer driver: build random indices and the model, then run
# IntegratedGradients on the softmax-wrapped model.

# seq_len random integers in [0, vocab_size). torch.Tensor(...) stores them as
# a float tensor of shape (seq_len,) -- there is no explicit batch dimension.
indexes = torch.Tensor(np.random.randint(0, vocab_size, (seq_len))).to(device)

model = predictor().to(device)
wrapper_model = wrapper_predictor(model).to(device)

# Attribute output index 0 of the wrapped model to `indexes`, using a single
# integration step (n_steps=1).
# NOTE(review): this is where the RuntimeError quoted in the post is expected
# to surface. Presumably the cause is that predictor.forward casts its input
# with x.long() before the embedding lookup, so the output is not
# differentiable with respect to `indexes` -- confirm against the Captum docs.
ig = IntegratedGradients(wrapper_model)
attributions, delta = ig.attribute(inputs=indexes, target=0, n_steps=1, return_convergence_delta=True)

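I resolved the issue with LayerIntegratedGradients, which computes attributions with respect to the output of the embedding layer instead of the integer indices passed to the model.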

Here is a link where you can read about other possible solutions: Captum · Model Interpretability for PyTorch

Sample code

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients, LayerIntegratedGradients
from torchsummary import summary

# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Hyper-parameters read by `predictor.__init__`.
vocab_size, embedding_dim = 1024, 1  # embedding table: 1024 rows, 1 value each
seq_len = 128                        # indices per example
hidden_dim, num_classes = 256, 5     # widths of the two linear layers' outputs

class predictor(nn.Module):
    """Embedding lookup followed by two bias-free linear layers.

    Layer sizes are taken from the module-level ``vocab_size``,
    ``embedding_dim``, ``seq_len``, ``hidden_dim`` and ``num_classes``
    values, which are copied onto the instance when it is constructed.
    The linear layers are created directly on the module-level ``device``.
    """

    def __init__(self):
        super().__init__()
        self.seq_len = seq_len
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.vocab_size, self.embedding_dim = vocab_size, embedding_dim

        self.embedding = nn.Sequential(
            nn.Embedding(self.vocab_size, self.embedding_dim),
        )
        # No manual weight initialisation here: nn.Embedding already draws its
        # weight from N(0, 1). Assigning a plain tensor to
        # `self.embedding.weight` would only set an unregistered attribute on
        # the Sequential container; it would never reach the Embedding, would
        # not follow `.to(device)`, and would not be used in `forward`.
        self.fc = nn.Sequential(
            nn.Linear(self.seq_len * self.embedding_dim, self.hidden_dim, device=device, bias=False),
            nn.Linear(self.hidden_dim, self.num_classes, device=device, bias=False),
        )

    def forward(self, x):
        """Return the output of ``fc`` for the index tensor ``x``.

        ``x`` is cast to ``long`` before the lookup. The embedded values are
        viewed as rows of ``seq_len * embedding_dim`` elements and passed
        through the two linear layers; no activation is applied.
        """
        x = self.embedding(x.long())
        x = x.view(-1, self.seq_len * self.embedding_dim)
        x = self.fc(x)
        return x

class wrapper_predictor(nn.Module):
    """Apply a softmax over dim 1 to the output of the wrapped ``model``."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # registered as a sub-module

    def forward(self, x):
        """Run ``self.model`` on ``x`` and normalise the result with softmax (dim=1)."""
        scores = self.model(x)
        return F.softmax(scores, dim=1)

# Driver: build the model, print a layer summary, then attribute output index 0
# to the output of the embedding layer with LayerIntegratedGradients.
model = predictor().to(device)

# Captum treats dimension 0 of every input as the batch dimension, so the
# random indices get an explicit batch axis: shape (1, seq_len).
indexes = torch.Tensor(np.random.randint(0, vocab_size, (1, seq_len))).to(device)
# torchsummary expects the per-example shape (without the batch axis) and a
# device string of "cuda" or "cpu"; `device.type` keeps this working on
# CPU-only machines instead of hard-coding "cuda".
input_size = tuple(indexes.shape[1:])
summary(model=model, input_size=input_size, batch_size=-1, device=device.type)

# NOTE: the softmax wrapper is built but not used below; attributions are
# computed against the raw output of `model`.
wrapper_model = wrapper_predictor(model).to(device)

# n_steps=1 is a single-step approximation of the path integral; a larger
# n_steps should give a smaller convergence delta.
lig = LayerIntegratedGradients(model, model.embedding)
attributions, delta = lig.attribute(inputs=indexes, target=0, n_steps=1, return_convergence_delta=True)