Obscure error message when trying to export a compiled function

I am trying to compile and save a function that uses PyTorch to do some matrix computation on the GPU.

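```python
import os
from typing import Optional

import torch


def grad(self, input: torch.Tensor, matmul_result: Optional[torch.Tensor] = None) -> torch.Tensor:
    if matmul_result is not None:
        inputs = [matmul_result]
        # Check for the same file name that is saved below (.ep, not .ts)
        if not os.path.isfile("trt_model.ep"):
            # JIT-compile the helper method with the Torch-TensorRT backend
            self.grad_aux_compile = torch.compile(self.get_grad_aux_to_compile, backend="torch_tensorrt")
            # This export call raises the AssertionError
            trt_exp_program = torch.export.export(self.get_grad_aux_compile, args=tuple(inputs))
            torch.export.save(trt_exp_program, "trt_model.ep")
        elif self.grad_aux_compile is None:
            # A file written by torch.export.save is reloaded with torch.export.load,
            # not torch.jit.load; .module() returns a callable module
            self.grad_aux_compile = torch.export.load("trt_model.ep").module()
        tmp = self.grad_aux_compile(*inputs)
    else:
        tmp = self.get_grad_aux(input, matmul_result)
    gradient_size = tmp.size()
    # Normalize by the second dimension, or by the length of a 1-D result
    if len(gradient_size) > 1:
        n = gradient_size[1]
    else:
        n = gradient_size[0]
    return self.add_dual_pred_tensor(tmp, None, 1.0 / n)
```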

When the code runs, the following line raises this error:

```python
trt_exp_program = torch.export.export(self.get_grad_aux_compile, args=tuple(inputs))
```

```
AssertionError: graph-captured input #2, of type <class 'torch.Tensor'>, is not among original inputs of types: (<class 'torch.Tensor'>)
```

I do not know where to start. I pass exactly the same inputs that I use when calling the function, so I am quite lost. Is there a way to see the captured graph? I would like to understand which inputs it expects.
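One way to see what Dynamo captured is a small debugging backend for `torch.compile` that prints the FX graph and then runs it unchanged. In the sketch below, `Helper`, `weight` and `inspect_backend` are hypothetical names standing in for the real class. The `placeholder` nodes it prints are the inputs the captured graph expects; tensors read from `self` typically show up there as extra placeholders next to `x`, which looks like the kind of "graph-captured input" the assertion complains about. Setting the environment variable `TORCH_LOGS="graph_code"` should print similar information without a custom backend.

```python
import torch


class Helper:
    """Hypothetical stand-in for the poster's class: a plain object holding a tensor attribute."""

    def __init__(self) -> None:
        self.weight = torch.randn(8, 8)

    def get_grad_aux(self, x: torch.Tensor) -> torch.Tensor:
        # Reads a tensor from self, not only from its arguments
        return x @ self.weight


captured_graphs = []


def inspect_backend(gm: torch.fx.GraphModule, example_inputs: list):
    """Debug backend: record and print the captured FX graph, then run it as-is."""
    captured_graphs.append(gm)
    gm.print_readable()
    placeholders = [node.name for node in gm.graph.nodes if node.op == "placeholder"]
    print("graph inputs:", placeholders)
    return gm.forward  # no optimization, just eager execution of the graph


helper = Helper()
x = torch.randn(4, 8)
compiled = torch.compile(helper.get_grad_aux, backend=inspect_backend)
out = compiled(x)  # compilation (and printing) happens on the first call
```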

Could it be because the method uses a class attribute and is not static? The method does run correctly when compiled.

I also tried torch.jit, but it did not improve performance.
I also saw that there are other options when you compile a module (nn.Module) rather than a function.

Is there another way to export the compiled function?