Deduce input shape from torch.jit.ScriptModule

The torch.jit.trace() Python function accepts an nn.Module and an example input, and returns a torch.jit.ScriptModule.

Is there a way to extract the expected input shape from the torch.jit.ScriptModule? In other words, if the second argument passed to trace() was a Tensor of shape (1, 2, 3, 4), is there a way to recover (_, 2, 3, 4) from the torch.jit.ScriptModule that trace() returns?


If you do something like this:

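import torch

def fn(x):
    return x.relu()

# Trace with a concrete example input; its shape is recorded during tracing.
x = torch.rand((4, 4))
fn_t = torch.jit.trace(fn, (x,))

# Print the traced graph to inspect the recorded input types.
print(fn_t.graph)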

Then I believe the printed graph will show shape info. You can also access the input types programmatically; see some minimal examples here: pytorch/ at master · pytorch/pytorch · GitHub

However, this is an implementation detail. The shapes shown in the graph are just an artifact of the tracing process which could change in the future.
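For illustration, here is a minimal sketch of reading the recorded input sizes from the traced graph. The helper name traced_input_sizes is hypothetical, and the sketch assumes the graph inputs still carry concrete tensor types, which (per the caveat above) is not guaranteed and may not hold after saving and reloading.

```python
import torch

def fn(x):
    return x.relu()

def traced_input_sizes(traced):
    """Collect the sizes recorded on each tensor-typed graph input (hypothetical helper)."""
    sizes = []
    for value in traced.graph.inputs():
        value_type = value.type()
        # Skip non-tensor inputs, such as a traced module's `self` argument.
        if value_type.kind() != "TensorType":
            continue
        # sizes() returns None when no concrete shape was recorded for the input.
        sizes.append(value_type.sizes())
    return sizes

fn_t = torch.jit.trace(fn, (torch.rand((4, 4)),))
shapes = traced_input_sizes(fn_t)
print(shapes)
```

Each entry is either a list of dimensions or None, so callers should handle the None case rather than rely on the shape being present.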

Thanks. If this is the best way to get the input shape in Python, I am guessing that the C++ torch::jit::script::Module won’t offer much better options. In my context, I was hoping to save the Python torch.jit.ScriptModule to disk, load it in C++, and infer the input shape at C++ runtime.
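One possible workaround, not suggested in the thread itself, is to store the expected shape next to the model explicitly rather than infer it from the graph. The sketch below assumes the _extra_files argument of torch.jit.save and torch.jit.load; the file name input_shape.json and the convention of using None for the batch dimension are made up for this example.

```python
import io
import json

import torch

example = torch.rand((1, 2, 3, 4))
traced = torch.jit.trace(torch.nn.ReLU(), example)

# Hypothetical convention: None marks the batch dimension as unconstrained.
expected_shape = [None] + list(example.shape[1:])

# Save the model together with a small JSON side file describing the input shape.
buffer = io.BytesIO()
torch.jit.save(
    traced,
    buffer,
    _extra_files={"input_shape.json": json.dumps(expected_shape)},
)

# On load, the requested keys are filled in with the stored file contents.
buffer.seek(0)
extra_files = {"input_shape.json": ""}
loaded = torch.jit.load(buffer, _extra_files=extra_files)
recovered_shape = json.loads(extra_files["input_shape.json"])
print(recovered_shape)
```

The C++ torch::jit::load function also has an overload that takes an extra-files map, so the same side file should be readable at C++ runtime, though that part is not shown here.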
