When I tried to export my trained pytorch model to ONNX format, I encounter the error: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
After searching on board, I found multiple cases that results in same error, but I didn’t find a solution suitable for my case.
Here is my code:
# Define a function: input raw data => preprocessing & model inference => prediction
from PIL import Image
# Preprocess image for PyTorch data architecture
from torchvision import transforms
data_preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.456, 0.225)
])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def process_for_model(img):
'''
Process image array for PyTorch model
img: input tensor of an image
return: Tensor image prepared to input to the model
'''
#img = Image.fromarray(img).convert('RGB')
img = data_preprocess(img) # remember to do data preprocessing as in the training stage!! This strongly influence the testing performance
img = img.to(device)
img = img.view(1, 3, 889, 929)
return img
def pipeline(input_batch, preprocess, inference):
processed_data = preprocess(input_batch)
output = inference(processed_data)
return output
def final_pipeline(input_batch):
prediction = pipeline(input_batch)
return prediction
class mypipeline(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input_batch):
input_batch = input_batch.cpu().detach().numpy().astype(np.float32)
processed_data = process_for_model(input_batch)
print("processed data:", processed_data)
self.output = pre_model(processed_data).cpu().detach()
return self.output
input_shape = (889, 929, 3)
dummy_input = torch.randint(0, 255, size = (889, 929, 3), device = torch.device("cuda:0"))
pre_model.eval()
pipeline = mypipeline()
pipeline.eval()
with torch.no_grad():
torch.onnx.export(pipeline,
dummy_input,
"txt_overlap_pip.onnx",
verbose = False,
)
I’ve numpy array in the workflow, which I’m not sure if it would be the cause of problem.
Any advice is appreciated, thank you!