I have a model that contains a reshape operation (essentially to do something like group normalisation, but different): I reshape so that the channel dimension is split into two dimensions, sum over one of them, divide by that sum, and then reshape back.
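For reference, a minimal sketch of the operation as I described it (a hypothetical reconstruction, not my actual module; `GroupSumNorm`, `groups=4` and the `eps` guard are assumptions):

```python
import torch
import torch.nn as nn

class GroupSumNorm(nn.Module):
    """Hypothetical sketch: split channels into (C // groups, groups),
    sum over the group axis, divide by that sum, reshape back."""

    def __init__(self, groups: int = 4, eps: float = 1e-5):
        super().__init__()
        self.groups = groups
        self.eps = eps  # guard against division by zero (an assumption)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c = x.shape[0], x.shape[1]
        spatial = x.shape[2:]
        # (B, C, ...) -> (B, C // groups, groups, ...)
        y = x.reshape(b, c // self.groups, self.groups, *spatial)
        s = y.sum(dim=2, keepdim=True)    # sum over the group axis
        y = y / (s + self.eps)            # divide by the per-group sum
        return y.reshape(b, c, *spatial)  # restore (B, C, ...)
```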
This works fine during training and testing, but when I jit.trace the model I get a malformed model in which 'self' gets overwritten (see the 'self = …' line), as shown in this part of the generated code.py:
x_70 = torch.add_(x_69, input_65, alpha=1)
_288 = ops.prim.NumToTensor(torch.size(x_70, 0))
_289 = int(_288)
_290 = int(_288)
self = ops.prim.NumToTensor(torch.size(x_70, 1))
_291 = int(self)
_292 = ops.prim.NumToTensor(torch.size(x_70, 2))
_293 = int(_292)
_294 = int(_292)
_295 = ops.prim.NumToTensor(torch.size(x_70, 3))
_296 = int(_295)
_297 = int(_295)
_298 = ops.prim.NumToTensor(torch.size(x_70, 4))
_299 = int(_298)
_300 = int(_298)
_301 = [_290, int(torch.div(self, CONSTANTS.c0)), 4, _294, _297, _300]
x_71 = torch.reshape(x_70, _301)
When I replace 'self' with 'self_19' it is all right, and I can load the model.
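The manual fix can be sketched as a small text rewrite applied to the offending lines only (running it over the whole code.py would also rename the real 'self' parameter in method signatures, so this helper is a hypothetical illustration, not a general tool):

```python
import re

def rename_shadowed_self(lines: str, new_name: str = "self_19") -> str:
    """Rename the bare `self` the tracer emitted as a local variable.

    Leaves attribute access like `self.conv` untouched; intended to be
    applied only to the block where `self` is (re)assigned.
    """
    return re.sub(r"\bself\b(?!\.)", new_name, lines)
```

For example, `rename_shadowed_self("_291 = int(self)")` yields `"_291 = int(self_19)"`.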
However, I also have issues exporting to ONNX, which complains about the reshape operation.
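One thing I noticed in the trace is that the shape list is built from `int(torch.size(...))` calls, which the tracer bakes in as Python integers; a variant that lets reshape infer the split size with -1 instead of computing `C // groups` in Python might be friendlier to tracing and ONNX export (again a sketch under that assumption, not a confirmed fix):

```python
import torch
import torch.nn as nn

class GroupSumNormExportable(nn.Module):
    """Sketch of a trace/export-friendlier variant: avoid computing the
    grouped channel size from int(x.size(1)) and let reshape infer it."""

    def __init__(self, groups: int = 4, eps: float = 1e-5):
        super().__init__()
        self.groups = groups
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # -1 lets reshape derive C // groups instead of hard-coding it
        y = x.reshape(x.shape[0], -1, self.groups, *x.shape[2:])
        y = y / (y.sum(dim=2, keepdim=True) + self.eps)
        return y.reshape(x.shape)
```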
And then I have trouble running the model through the C++ API: it does not work on GPU on Linux (but it works on CPU on Linux, and on both GPU and CPU on Windows).
I have a feeling all these problems are related. Is there something known about the reshape operation that causes this?