When compiling to TorchScript, either with tracing or scripting, I often have problems with operations that depend explicitly on tensor sizes. Sometimes sizes get hardcoded as constants, breaking compatibility with variable batch sizes. Other times PyTorch warns that I'm converting torch integers to Python ints, which will be traced as constants.
I think these problems stem from my poor understanding of how torch.Size objects work and how they are traced. This pull request mentions the issue and says that torch.Size objects are basically tuples of torch ints, and that the use of PyTorch integers is what allows for correct tracing. But when I check, the data type of each entry of a tensor size is a plain Python int. So how do torch.Size objects work internally? And how should I manipulate them in order to get correct TorchScript compilation?
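To illustrate the check I mean, here is a quick standalone sketch (nothing here is specific to my model):

```python
import torch

x = torch.randn(3, 5)
print(type(x.shape))                  # <class 'torch.Size'>
print(type(x.shape[0]))               # <class 'int'>: a plain Python int, not a torch integer
print(issubclass(torch.Size, tuple))  # True: torch.Size subclasses tuple
```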
Additional note: I know that tracing isn’t guaranteed to compile data-dependent operations correctly, but:
- Tracing is often more convenient than scripting.
- Manipulations of tensor sizes should be recorded correctly by tracing.
I've added an MWE to showcase the problem I'm facing. Define the following toy class:

```python
import torch
from torch import nn


class Foo(nn.Module):
    """Toy class that plays with tensor shape to showcase the tracing issue."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        new_shape = (x.shape[0], 2 * x.shape[1])  # incriminated instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2
```
and run the test code:

```python
x = torch.randn((3, 5))               # create example input
foo = Foo()
traced_foo = torch.jit.trace(foo, x)  # trace
print(traced_foo(x).shape)            # obviously this works
print(traced_foo(x[:, :4]).shape)     # but fails with a different shape!
```
So the problem here is that the tensor sizes are hardcoded as constants instead of being traced. How can I overcome this issue?
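One reformulation I've been considering avoids explicit shapes entirely, so tracing has nothing to hardcode (a sketch on my part; I'd still like to understand the torch.Size behaviour itself rather than just work around it):

```python
import torch

def forward_shape_free(x):
    # Interleave x and x + 1 along the last dim without building an explicit shape:
    # stack gives (N, M, 2); flatten(1) gives (N, 2*M) with the same layout as the
    # slice assignments x2[:, ::2] = x, x2[:, 1::2] = x + 1 in the class above.
    return torch.stack((x, x + 1), dim=-1).flatten(1)

x = torch.randn(3, 5)
traced = torch.jit.trace(forward_shape_free, x)
print(traced(x[:, :4]).shape)  # torch.Size([3, 8]): a different width now works
```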