I don't understand how the `squeeze` function is handled when exporting a model to ONNX. For example, in the following sample script, the behavior of the exported operators changes depending on whether or not the `grid_sample` call (commented out in `forward`) is enabled.
I am seeking an explanation or a link to a reference document regarding the above.
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from onnxruntime.tools import pytorch_export_contrib_ops
# Register onnxruntime's contrib-op export functions with torch.onnx; this must
# happen before torch.onnx.export is called in deploy().
# NOTE(review): presumably this lets the exporter emit onnxruntime contrib ops
# (possibly for grid_sample) instead of failing or using a standard op —
# confirm whether it affects the graph compared in this question.
pytorch_export_contrib_ops.register()
class Model(nn.Module):
    """Minimal module used to reproduce the squeeze/ONNX-export question.

    The forward pass creates an all-zeros tensor from the input's first three
    sizes, resizes it to the last two sizes of the padded input, and finally
    drops the leading axis with ``squeeze(0)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        dims = (x.size(0), x.size(1), x.size(2))
        zeros = x.new_zeros(dims)
        # F.pad with [2, 2] widens the last dimension by 2 on each side; only
        # the resulting trailing two sizes are used below.
        padded = F.pad(x, [2, 2])
        target_size = padded.shape[-2:]
        # zeros is 3-D, so zeros[None] is 4-D with a leading axis of size 1.
        resized = F.interpolate(zeros[None], size=target_size)
        # Per the question, enabling the next line changes which operators the
        # exporter produces for the squeeze(0) below:
        # resized = torch.nn.functional.grid_sample(resized, torch.randn(1, 30, 30, 2))
        return resized.squeeze(0)
def deploy(path: str = "tmp.onnx"):
    """Trace ``Model``, export it to ONNX, and print the exported graph.

    Args:
        path: File the ONNX model is written to and read back from. Defaults
            to ``"tmp.onnx"`` (relative to the current working directory),
            the location the script always used.

    Returns:
        The model loaded back with ``onnx.load(path)``, so callers can inspect
        the exported nodes (for example compare ``node.op_type`` values with
        and without the ``grid_sample`` line in ``Model.forward``) instead of
        only reading the printed graph.
    """
    module = Model().to("cpu").eval()
    example_input = torch.randn(20, 16, 50, 100)
    # Export the traced module (not the eager one), using the same example
    # input for tracing and for export.
    traced = torch.jit.trace(module, example_input)
    torch.onnx.export(
        traced,
        example_input,
        path,
        verbose=False,
        do_constant_folding=True,
        opset_version=16,
        input_names=["input"],
    )
    exported = onnx.load(path)
    print(onnx.helper.printable_graph(exported.graph))
    return exported


if __name__ == "__main__":
    deploy()
Thank you.