Trouble with reshape and jit/trace


(Marijn Stollenga) #1

I have a model with a reshape operation inside it (essentially something like group normalisation, but different). I reshape the tensor so that the channel dimension is split into two dimensions, sum over one of them, divide by that sum, and then reshape it back.
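A minimal sketch of that kind of layer, assuming the channel count is divisible by a hypothetical group size of 4 and that each group is divided by its own sum. The class name `GroupSumNorm` and the `eps` guard are illustrative, not taken from the original model:

```python
import torch


class GroupSumNorm(torch.nn.Module):
    """Hypothetical layer: split channels into groups and divide each group by its sum."""

    def __init__(self, group_size: int = 4, eps: float = 1e-6):
        super().__init__()
        self.group_size = group_size  # assumption: C is divisible by this
        self.eps = eps  # assumption: guards against division by zero

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c = x.shape[0], x.shape[1]
        # (N, C, *spatial) -> (N, C // G, G, *spatial)
        g = x.reshape(n, c // self.group_size, self.group_size, *x.shape[2:])
        # Sum over the group axis, keeping it so the division broadcasts
        accum = g.sum(2, keepdim=True)
        g = g / (accum + self.eps)
        # Restore the original (N, C, *spatial) shape
        return g.reshape(x.shape)


x = torch.rand(2, 8, 3, 3) + 0.1  # positive input so group sums are non-zero
y = GroupSumNorm()(x)
print(y.shape)  # torch.Size([2, 8, 3, 3])
```

The `c // self.group_size` term is the kind of shape arithmetic that, when traced, shows up as a `torch.div` applied to a traced size value in the code.py excerpt.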

This works fine during training and testing, but when I trace the model with torch.jit.trace, I get a malformed model in which ‘self’ is overwritten (see the ‘self = …’ line), as in this part of code.py:

```python
x_70 = torch.add_(x_69, input_65, alpha=1)
_288 = ops.prim.NumToTensor(torch.size(x_70, 0))
_289 = int(_288)
_290 = int(_288)
self = ops.prim.NumToTensor(torch.size(x_70, 1))
_291 = int(self)
_292 = ops.prim.NumToTensor(torch.size(x_70, 2))
_293 = int(_292)
_294 = int(_292)
_295 = ops.prim.NumToTensor(torch.size(x_70, 3))
_296 = int(_295)
_297 = int(_295)
_298 = ops.prim.NumToTensor(torch.size(x_70, 4))
_299 = int(_298)
_300 = int(_298)
_301 = [_290, int(torch.div(self, CONSTANTS.c0)), 4, _294, _297, _300]
x_71 = torch.reshape(x_70, _301)
```

When I replace ‘self’ with ‘self_19’, it’s alright and I can load the model.
However, I also have issues exporting to ONNX, which complains about the reshape operation.
I also have trouble running the model with the C++ API: it does not work on GPU on Linux (but it works on CPU on Linux, and on both GPU and CPU on Windows).

I have a feeling all these problems are related. Is there something known about the reshape operation that causes this?
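One way to narrow down problems like these is to check whether a traced module survives a save/load round trip and still matches eager mode, on every available device. A minimal sketch using an in-memory buffer instead of a file; the helper `roundtrip_matches` and the `Toy` module are hypothetical stand-ins, not the original model:

```python
import io

import torch


def roundtrip_matches(module: torch.nn.Module, example: torch.Tensor) -> bool:
    """Trace, save, and reload `module`, then compare with eager output."""
    module.eval()
    traced = torch.jit.trace(module, example)
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)
    # A malformed serialized graph would surface here as a load error
    reloaded = torch.jit.load(buffer)
    with torch.no_grad():
        return torch.allclose(module(example), reloaded(example))


class Toy(torch.nn.Module):
    """Hypothetical stand-in using the same view-based shape arithmetic."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        b = x.view(x.shape[0], x.shape[1] // 2, 2)
        b = b * b.sum(1, keepdim=True)
        return self.lin(b.view(x.shape))


# Check every available device, since CPU and GPU behaved differently here
devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
results = {}
for device in devices:
    model = Toy().to(device)
    example = torch.randn(4, 4, device=device)
    results[device] = roundtrip_matches(model, example)
print(results)
```

The C++ API reads the same serialized format (via `torch::jit::load`), so running a check like this in Python first can help separate serialization problems from C++- or platform-specific ones.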


(Michael Suo) #2

Thanks for the report! Seems like it may be a problem with our serialization code. Could you provide a small module/script that reproduces the problem so that we can investigate?


(Marijn Stollenga) #3

OK, to reproduce it I use a reshaping operation. It is essential that ‘view’ gets a shape that is calculated partly from another shape, as that seems to cause the trouble. I added a linear layer afterwards to make sure ‘self’ is used again, and then it fails:

```python
#!/usr/bin/ipython3

import torch


class Example(torch.nn.Module):
    def __init__(self):
        super(Example, self).__init__()

    def forward(self, x):
        s = x.shape
        # The view shape is computed partly from x.shape, which triggers the issue
        b = x.view(x.shape[0], x.shape[1] // 2, 2)
        accum = b.sum(1, keepdim=True)
        b = b * accum
        return b.view(*s)


class ExampleNested(torch.nn.Module):
    def __init__(self):
        super(ExampleNested, self).__init__()
        self.ex = Example()
        # Linear layer so that the generated code has to refer to `self` again
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.ex(x)
        x = self.lin(x)
        return x


a = torch.randn(4, 4)

example = ExampleNested()
traced = torch.jit.trace(example, a)
traced.save("trace.tmp")
```

This gives me:

```python
op_version_set = 0
def forward(self,
    x: Tensor) -> Tensor:
  _0 = ops.prim.NumToTensor(torch.size(x, 0))
  _1 = int(_0)
  _2 = ops.prim.NumToTensor(torch.size(x, 1))
  _3 = int(_2)
  _4 = ops.prim.NumToTensor(torch.size(x, 0))
  _5 = int(_4)
  self = ops.prim.NumToTensor(torch.size(x, 1))
  _6 = [_5, int(torch.div(self, CONSTANTS.c0)), 2]
  b_1 = torch.view(x, _6)
  accum = torch.sum(b_1, [1], True)
  b = torch.mul(b_1, accum)
  input = torch.view(b, [_1, _3])
  _7 = torch.addmm(self.lin.bias, input, torch.t(self.lin.weight), beta=1, alpha=1)
  return _7
```
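Until a fixed release is available, one possible workaround (an untested assumption, not something confirmed in this thread) is to let `view` infer the split dimension with `-1`, so the forward pass does no arithmetic on `x.shape[1]` and the trace has no `torch.div` applied to a size value like the one bound to `self` in the output. A sketch of the repro rewritten that way; the class names are hypothetical:

```python
import io

import torch


class ExampleNoShapeMath(torch.nn.Module):
    """Same computation as `Example`, but view infers the channel split."""

    def forward(self, x):
        s = x.shape
        # assumption: -1 lets view compute x.shape[1] // 2 itself (valid for
        # 2-D input with an even second dimension), so no division is traced
        b = x.view(x.shape[0], -1, 2)
        accum = b.sum(1, keepdim=True)
        b = b * accum
        return b.view(*s)


class ExampleNestedNoShapeMath(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ex = ExampleNoShapeMath()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(self.ex(x))


a = torch.randn(4, 4)
model = ExampleNestedNoShapeMath()
traced = torch.jit.trace(model, a)
print(traced.code)  # inspect the generated TorchScript for the view shapes

# Save and reload in memory to confirm the serialized module still loads
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)
```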


(Marijn Stollenga) #4

Btw, to follow up: I only see this problem with the last such operation. No matter how many of these layers I chain together, only the last one gets the wrong ‘self’ name (without the number at the end).
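The chained setup described here can be sketched by stacking one to three copies of the `Example` block in front of a linear layer and checking that each traced model still saves, reloads, and matches eager mode. Using `torch.nn.Sequential` and a maximum depth of three are choices made for illustration:

```python
import io

import torch


class Example(torch.nn.Module):
    # Same reshape block as in the repro above (it has no parameters)
    def forward(self, x):
        s = x.shape
        b = x.view(x.shape[0], x.shape[1] // 2, 2)
        accum = b.sum(1, keepdim=True)
        b = b * accum
        return b.view(*s)


a = torch.randn(4, 4)
reload_ok = {}
for depth in (1, 2, 3):
    # `depth` reshape blocks in a row, then a Linear layer as in the repro
    blocks = [Example() for _ in range(depth)]
    model = torch.nn.Sequential(*blocks, torch.nn.Linear(4, 4))
    traced = torch.jit.trace(model, a)
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)
    # On the affected version, loading is where the clobbered `self` surfaced
    reloaded = torch.jit.load(buffer)
    with torch.no_grad():
        reload_ok[depth] = torch.allclose(reloaded(a), model(a))
print(reload_ok)
```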


(Michael Suo) #5

Thanks for the report. I’ve filed a GH issue here and we will update you there.


(Marijn Stollenga) #6

I see from the issue that the fix is in 1.0.1! I’ll try it now, since the pip package seems to have been updated.
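Before re-running the repro, it can help to confirm that the installed build is at least 1.0.1. A small sketch; the helper name is hypothetical, and it compares only the leading major.minor.patch numbers of the version string:

```python
import re

import torch


def has_serialization_fix(version: str) -> bool:
    """Return True if `version` is at least 1.0.1, the release named in the issue."""
    # Keep only the leading numeric part, e.g. "2.1.0+cu118" -> (2, 1, 0)
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups()) >= (1, 0, 1)


print(torch.__version__, has_serialization_fix(torch.__version__))
```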