Callback for Integrated Gradients

I am using PyTorch Lightning for training and testing, and I have tried to incorporate Integrated Gradients calculations in both the test step and the predict step. Both throw `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`, and I have no idea how to tackle this. I am now planning to implement the IG extraction with Captum through a Callback. Is that possible?

I am still stuck on the same error:

```
  File "/home/neelamlab/anaconda3/envs/torch2.5.1/lib/python3.12/site-packages/torch/autograd/__init__.py", line 496, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/neelamlab/anaconda3/envs/torch2.5.1/lib/python3.12/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
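A likely cause, assuming the Trainer uses its default `inference_mode=True`: Lightning runs the test and predict loops under `torch.inference_mode()`, and `torch.enable_grad()` does not switch inference mode off, so the forward pass records no autograd graph. The plain-PyTorch sketch below reproduces the same error without Lightning or Captum; the toy `nn.Linear` model is only a stand-in for the real network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)                    # toy stand-in for the regression head
x = torch.randn(2, 4, requires_grad=True)  # normal tensor that requires grad

error_message = None
with torch.inference_mode():               # what Lightning's test/predict loop uses by default
    with torch.enable_grad():              # what @torch.enable_grad() on the hook does
        y = model(x).sum()
        graph_recorded = y.grad_fn is not None
        try:
            torch.autograd.grad(y, x)      # Captum ends up calling torch.autograd.grad too
        except RuntimeError as err:
            error_message = str(err)

print("graph recorded:", graph_recorded)
print("error:", error_message)
```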

Callback implementation:

```python
import torch
from pytorch_lightning.callbacks import Callback

class IntegratedGradientsCallback(Callback):
    def __init__(self, output_dir):
        super().__init__()
        self.output_dir = output_dir  # directory for saving attribution maps

    @torch.enable_grad()
    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        inputs, labels, label_info = batch  # unpack batch
        regression_model = RegressionHeadWrapper(pl_module)  # expose only the regression head
        preds = regression_model(inputs)  # renamed so the hook's `outputs` argument is not shadowed
        print(preds)

        attr_map = getInputAttributions(regression_model, inputs)  # IG computation (see below)
        print(attr_map)
```
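One way around this, sketched under the assumption that the hook runs inside `torch.inference_mode()`: leave inference mode locally with `torch.inference_mode(False)` and clone the batch, because tensors created under inference mode are "inference tensors" that cannot take part in autograd. The helper name `input_gradients` is hypothetical, and a single plain gradient stands in for the Captum call.

```python
import torch
import torch.nn as nn


def input_gradients(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Gradient of the summed model output w.r.t. `inputs` (hypothetical helper).

    Intended to work even when called under torch.inference_mode(), such as
    from a Lightning test/predict hook.
    """
    # inference_mode(False) re-enables autograd recording; enable_grad() undoes any no_grad().
    with torch.inference_mode(False), torch.enable_grad():
        # clone() outside inference mode turns an inference tensor into a normal one;
        # detach() then makes it a fresh leaf that can require grad.
        x = inputs.clone().detach().requires_grad_(True)
        out = model(x).sum()
        (grad,) = torch.autograd.grad(out, x)
    return grad


torch.manual_seed(0)
lin = nn.Linear(3, 1)
with torch.inference_mode():      # mimics Lightning's default test/predict context
    batch = torch.randn(5, 3)     # an inference tensor, like a batch inside the loop
    grads = input_gradients(lin, batch)
print(grads.is_inference(), grads.shape)
```

Inside `on_test_batch_end`, the same `with` block could wrap the `IntegratedGradients(...).attribute(...)` call on the cloned inputs. Alternatively, constructing the Trainer with `inference_mode=False` makes Lightning evaluate under `torch.no_grad()` instead, which the existing `@torch.enable_grad()` decorator does override.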

My model has two heads, so I wrap it in a custom module that returns only the regression output when getting predictions and computing Integrated Gradients:

```python
import torch.nn as nn

class RegressionHeadWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        x_age, _ = self.model(x)  # keep only the regression output, drop the second head
        return x_age
```
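For reference, a small sketch of the wrapper on a hypothetical two-headed network (`TwoHeadNet` and its layer sizes are made up for illustration). The wrapped model returns one value per sample, shape `(N, 1)`, which is the case where Captum's `IntegratedGradients.attribute` can be called without a `target`; an output with several columns per sample would need one.

```python
import torch
import torch.nn as nn


class TwoHeadNet(nn.Module):
    """Hypothetical stand-in for the real model: returns (regression, classification)."""

    def __init__(self, in_features: int = 8, n_classes: int = 3):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_features, 16), nn.ReLU())
        self.reg_head = nn.Linear(16, 1)
        self.cls_head = nn.Linear(16, n_classes)

    def forward(self, x):
        h = self.backbone(x)
        return self.reg_head(h), self.cls_head(h)


class RegressionHeadWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        x_age, _ = self.model(x)  # keep only the regression output
        return x_age


torch.manual_seed(0)
wrapped = RegressionHeadWrapper(TwoHeadNet()).eval()
x = torch.randn(4, 8, requires_grad=True)
age = wrapped(x)                             # one scalar per sample
(grad,) = torch.autograd.grad(age.sum(), x)  # gradients flow through the regression head only
print(age.shape, grad.shape)
```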

IG call:

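```python
from captum.attr import IntegratedGradients

def getInputAttributions(model, input_tensor):
    input_tensor.requires_grad_()
    ig = IntegratedGradients(model)
    # config is the settings dict defined elsewhere in the script
    attr = ig.attribute(input_tensor, n_steps=10, internal_batch_size=config['batch_size'])
    return attr
```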
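To make the computation concrete, here is a plain-PyTorch approximation of Integrated Gradients, assuming a zero baseline and a model that returns one scalar per sample. It uses a midpoint Riemann sum, not Captum's implementation (whose default integration method is Gauss-Legendre), and `integrated_gradients` is a hypothetical helper. It also leaves inference mode locally, so it can be called from a Lightning test or predict hook.

```python
from typing import Optional

import torch
import torch.nn as nn


def integrated_gradients(model: nn.Module, inputs: torch.Tensor,
                         baseline: Optional[torch.Tensor] = None,
                         n_steps: int = 50) -> torch.Tensor:
    """Midpoint Riemann-sum approximation of Integrated Gradients (hypothetical helper)."""
    if baseline is None:
        baseline = torch.zeros_like(inputs)
    # Leave inference mode (if active) so that autograd records a graph again.
    with torch.inference_mode(False), torch.enable_grad():
        x = inputs.clone().detach()   # clone() turns an inference tensor into a normal one
        b = baseline.clone().detach()
        total_grad = torch.zeros_like(x)
        for k in range(n_steps):
            alpha = (k + 0.5) / n_steps                      # midpoint of the k-th sub-interval
            point = (b + alpha * (x - b)).requires_grad_(True)
            # Summing over the batch is fine when samples do not interact (eval mode).
            out = model(point).sum()
            (grad,) = torch.autograd.grad(out, point)
            total_grad += grad
        attributions = (x - b) * total_grad / n_steps        # (x - x') times the averaged gradient
    return attributions


torch.manual_seed(0)
lin = nn.Linear(5, 1)
xs = torch.randn(3, 5)
attr = integrated_gradients(lin, xs, n_steps=20)
print(attr.sum(dim=1))                       # compare with f(xs) - f(0) per sample
print((lin(xs) - lin(torch.zeros_like(xs))).squeeze(1))
```

For a linear model the gradient is constant along the path, so the attributions equal `inputs * weight` exactly and sum to `f(x) - f(0)` per sample (the completeness property); for a nonlinear model the sum only approximates that difference, with the error shrinking as `n_steps` grows.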