I am trying to convert a PyTorch model to ONNX. I have created a dummy_input of shape (1, 3, 224, 224), and when I run the code below I get the following error:
RuntimeError: Given groups=1, weight of size 32 4 3 3, expected input[1, 3, 224, 224] to have 4 channels, but got 3 channels instead
dummy_input.shape = (1, 3, 224, 224)
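The weight size in the error reads as (out_channels, in_channels // groups, kernel_height, kernel_width), so with groups=1 the layer that fails was built for 4 input channels. A minimal sketch reproducing this, assuming a standalone nn.Conv2d(4, 32, 3) with default stride and padding as a hypothetical stand-in for the model's first layer:

```python
import torch
import torch.nn as nn

# Hypothetical first layer matching "weight of size 32 4 3 3":
# Conv2d weight shape is (out_channels, in_channels // groups, kH, kW)
conv = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=3)
print(tuple(conv.weight.shape))  # (32, 4, 3, 3)

# A 3-channel input triggers the same channel-mismatch RuntimeError
raised = False
try:
    conv(torch.randn(1, 3, 224, 224))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# A 4-channel input is accepted (no padding, so 224 - 3 + 1 = 222)
out = conv(torch.randn(1, 4, 224, 224))
print(tuple(out.shape))  # (1, 32, 222, 222)
```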
Here is the code:
import torch

# `model` is the trained PyTorch model, defined elsewhere
input_shape = (3, 224, 224)
model_onnx_path = "torch_model.onnx"
model.eval()  # equivalent to model.train(False)

# Export the model to an ONNX file
# (torch.autograd.Variable is deprecated; a plain tensor works as the dummy input)
dummy_input = torch.randn(1, *input_shape)
print('dummy_input shape : ', str(dummy_input.shape))
torch.onnx.export(model,
                  dummy_input,
                  model_onnx_path,
                  verbose=False)
print("Export of torch_model.onnx complete!")
Is there anything wrong with this code? Please correct me.
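One way to keep the dummy input consistent with the model is to read the channel count from the model itself instead of hard-coding (3, 224, 224). A minimal sketch, assuming the first nn.Conv2d returned by model.modules() is the layer that receives the input (registration order usually, but not always, matches forward order); the helper names and the two-layer demo model are hypothetical:

```python
import torch
import torch.nn as nn


def first_conv_in_channels(model: nn.Module) -> int:
    """Return in_channels of the first nn.Conv2d in model.modules() order."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            return module.in_channels
    raise ValueError("model contains no nn.Conv2d layer")


def make_dummy_input(model: nn.Module, height: int = 224, width: int = 224) -> torch.Tensor:
    # Batch of one, with the channel count the first convolution expects
    channels = first_conv_in_channels(model)
    return torch.randn(1, channels, height, width)


def export_to_onnx(model: nn.Module, path: str = "torch_model.onnx") -> None:
    model.eval()  # same effect as model.train(False)
    dummy_input = make_dummy_input(model)
    torch.onnx.export(model, (dummy_input,), path, verbose=False)


# Hypothetical stand-in for the real model: its first conv expects 4 channels
demo_model = nn.Sequential(nn.Conv2d(4, 32, 3), nn.ReLU())
demo_input = make_dummy_input(demo_model)
with torch.no_grad():
    demo_output = demo_model(demo_input)
print(tuple(demo_input.shape), tuple(demo_output.shape))
```

For export the dummy input only needs the right shape and dtype. What the fourth channel means for real data (for example an alpha or depth plane) depends on how the model was trained, so real inputs would still need to be prepared with 4 channels.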