I am trying to compile and save a function that uses PyTorch to do some matrix computations on the GPU.
```python
import os

import torch


def grad(self, input: torch.Tensor, matmul_result: torch.Tensor = None) -> torch.Tensor:
    if matmul_result is not None:
        inputs = [matmul_result]
        if not os.path.isfile("trt_model.ep"):
            # First run: compile with the Torch-TensorRT backend, then export and save.
            self.grad_aux_compile = torch.compile(self.get_grad_aux_to_compile, backend="torch_tensorrt")
            trt_exp_program = torch.export.export(self.get_grad_aux_compile, args=tuple(inputs))
            torch.export.save(trt_exp_program, "trt_model.ep")
        elif self.grad_aux_compile is None:
            # Later runs: reload the saved program (torch.export.load, not torch.jit.load).
            self.grad_aux_compile = torch.export.load("trt_model.ep").module()
        tmp = self.grad_aux_compile(*inputs)
    else:
        tmp = self.get_grad_aux(input, matmul_result)
    gradient_size = tmp.size()
    # Normalize by the second dimension for 2-D results, otherwise by the first.
    if len(gradient_size) > 1:
        n = gradient_size[1]
    else:
        n = gradient_size[0]
    return self.add_dual_pred_tensor(tmp, None, 1.0 / n)
```
When the code runs, the following line raises the error below:
```python
trt_exp_program = torch.export.export(self.get_grad_aux_compile, args=tuple(inputs))
```

```
AssertionError: graph-captured input #2, of type <class 'torch.Tensor'>, is not among original inputs of types: (<class 'torch.Tensor'>)
```
I do not know where to start fixing this. I pass the same input that I use to run the function, so I am quite lost. Is there a way to see the graph? I would like to understand what input it expects.
Could this be because the method uses a class attribute and is not static? The method does run when compiled.
I also tried torch.jit, but it did not improve performance.
I also saw that other options exist when you compile a module rather than a function.
Is there another way to export the compiled function?