How Are Sizes Tracked When Tracing or Scripting?

When compiling to TorchScript, whether by tracing or scripting, I often run into problems with operations that depend explicitly on tensor sizes. Sometimes sizes get hardcoded as constants, which breaks compatibility with variable batch sizes. Other times PyTorch complains that I’m converting torch integers to Python ints, which will be traced as constants.

I think these problems stem from my poor understanding of how torch.Size objects work and how they are traced. This pull request mentions the issue and says that torch.Size objects are basically tuples of torch ints, and that using PyTorch integers is what allows sizes to be tracked correctly. But when I check, each entry of a tensor size is a plain Python int.
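For reference, this is the kind of check I mean. It is a minimal snippet run in eager mode, outside of any tracing: torch.Size is a subclass of Python's tuple, and its entries are plain Python ints. It does not show what (if anything) the tracer substitutes for those entries while a trace is being recorded.

```python
import torch

x = torch.randn(3, 5)
size = x.shape

# torch.Size is a subclass of Python's built-in tuple
print(type(size).__name__, isinstance(size, tuple))
# In eager mode (not tracing), every entry is a plain Python int
print([type(d).__name__ for d in size])
```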

How do torch.Size objects work internally? And how should I manipulate them to get correct TorchScript compilation?

Additional note: I know that tracing isn’t guaranteed to compile data-dependent operations correctly, but:

  1. Tracing is often more convenient than scripting.
  2. Manipulations of tensor sizes should be recorded correctly by tracing.

I’ve added an MWE (minimal working example) to showcase the problem I’m facing. Define the following toy class:

import torch
import torch.nn as nn


class Foo(nn.Module):
    """Toy class that plays with tensor shapes to showcase a tracing issue.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        new_shape = (x.shape[0], 2*x.shape[1])  # offending instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2

and run the test code:

x = torch.randn((3, 5))  # create example input

foo = Foo()
traced_foo = torch.jit.trace(foo, x)  # trace
print(traced_foo(x).shape)  # obviously this works
print(traced_foo(x[:, :4]).shape)  # but fails with a different shape!

So the problem here is that the tensor sizes are hardcoded as constants instead of being traced. How can I overcome this issue?
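One workaround worth sketching (my own assumption, not an authoritative fix) is to avoid computing the new size at all: stack x and x + 1 along a new last dimension and flatten from dimension 1, so the traced graph only contains shape-agnostic ops (torch.stack and flatten) and no Python-level size arithmetic. FooNoSizes is a hypothetical name, and the sketch assumes flatten is recorded with its start_dim rather than with the concrete output shape seen during tracing.

```python
import torch
import torch.nn as nn


class FooNoSizes(nn.Module):
    """Hypothetical rewrite of Foo that avoids explicit size arithmetic."""

    def forward(self, x):
        # (N, M) -> (N, M, 2): [..., 0] holds x, [..., 1] holds x + 1
        pairs = torch.stack((x, x + 1), dim=2)
        # (N, M, 2) -> (N, 2M): columns alternate x, x + 1, like Foo's output
        return pairs.flatten(start_dim=1)


x = torch.randn(3, 5)
traced = torch.jit.trace(FooNoSizes(), x)
print(traced(x).shape)
print(traced(x[:, :4]).shape)
```

The output matches Foo (even columns are x, odd columns are x + 1), but the interleaving is expressed entirely through tensor ops instead of allocating a buffer of a computed size and filling it by slices.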

I got an answer on StackOverflow.
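For comparison, here is a sketch of the same module compiled with torch.jit.script instead of tracing (not necessarily what the StackOverflow answer recommends). Under scripting, x.shape is evaluated on every call rather than recorded once, so the size arithmetic in forward is preserved. FooScriptable is a hypothetical name; the type annotations and the dtype/device arguments to torch.empty (so the buffer matches the input) are my additions.

```python
import torch
import torch.nn as nn


class FooScriptable(nn.Module):
    """Hypothetical scriptable variant of Foo (same logic, type-annotated)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Under torch.jit.script, x.shape entries are ints read at run time,
        # so the new size is recomputed for every input.
        new_shape = [x.shape[0], 2 * x.shape[1]]
        x2 = torch.empty(new_shape, dtype=x.dtype, device=x.device)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2


scripted = torch.jit.script(FooScriptable())
x = torch.randn(3, 5)
print(scripted(x).shape)
print(scripted(x[:, :4]).shape)
```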