I’ve trained a style transfer model based on this implementation.
As far as I can tell, the following code defines the shape of the input and the style features:
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).to(device)
style_image = utils.load_image(STYLE_IMAGE_PATH)
style_tensor = utils.itot(style_image).to(device)
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
style_gram = {}
for key, value in style_features.items():
    style_gram[key] = utils.gram(value)
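For reference, the Gram matrix step above is typically implemented like this (a minimal sketch; the repo's actual `utils.gram` may differ in its normalization):

```python
import torch

def gram(tensor):
    # tensor: (B, C, H, W) feature map from VGG
    B, C, H, W = tensor.shape
    # Flatten the spatial dims, then take the channel-wise inner product
    features = tensor.view(B, C, H * W)
    G = torch.bmm(features, features.transpose(1, 2))
    # Normalize by the number of elements per channel map
    return G / (C * H * W)
```

Because the spatial dims are summed out, each Gram matrix is (B, C, C) regardless of the input's H and W, which is why the style targets don't constrain the model to one image size.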
The trained model works on images of any shape.
I’ve been able to successfully export this to ONNX and use it from there.
Unfortunately, only with a static H and W, using the following code:
prior = torch.load("pth/" + name + ".pth")
model = transformer.TransformerNetwork() # [0]
model.load_state_dict(prior)
dummy_input = torch.randn(1, 3, 640, 480)  # 640 x 480 is an arbitrary static H, W
torch.onnx.export(model, dummy_input, "onnx/" + name + ".onnx", verbose=True,
                  input_names=["input_image"],
                  output_names=["stylized_image"])
[0]: transformer.py
I tried passing:
dynamic_axes={'input_image': {2: 'height', 3: 'width'}, 'stylized_image': {2: 'height', 3: 'width'}}
to torch.onnx.export, both with and without the axis labels.
Unfortunately, I don’t know what to change the values in dummy_input = torch.randn(1, 3, 640, 480)
to, or whether they should remain as they are.
What would be the proper method for describing this model in order to get a good ONNX conversion?
Thanks!