The torch.jit.trace() Python function accepts an nn.Module and an example input, and returns a traced module. Is there a way to extract the expected input shape from the resulting torch.jit.ScriptModule? In other words, if the trace() call was passed a Tensor of shape (1, 2, 3, 4) as its second argument, is there a way to recover (_, 2, 3, 4) from the torch.jit.ScriptModule it returned?
If you do something like this:
x = torch.rand((4, 4))
fn_t = torch.jit.trace(fn, (x,))
If you then print fn_t.graph, I believe the output will show shape info. You can also access the input types programmatically; there are some minimal examples in pytorch/test_python_bindings.py on the master branch of the pytorch/pytorch GitHub repository.
However, this is an implementation detail. The shapes shown in the graph are just an artifact of the tracing process which could change in the future.
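As a rough sketch of the programmatic route mentioned above, the snippet below traces a small function and reads the tensor sizes off the graph inputs. The function fn and the helper traced_input_shapes are hypothetical names made up for illustration, and the whole approach leans on the implementation detail described above, so it may not behave the same way across PyTorch versions.

```python
import torch


def fn(x):
    # Hypothetical function to trace; any tensor-to-tensor function would do.
    return x * 2 + 1


x = torch.rand((4, 4))
fn_t = torch.jit.trace(fn, (x,))

# Printing the graph shows the input types recorded during tracing.
print(fn_t.graph)


def traced_input_shapes(graph):
    """Collect the recorded sizes of every tensor-typed graph input.

    Hypothetical helper. Non-tensor inputs (for example the `self`
    argument of a traced nn.Module) are skipped. An entry is None when
    no concrete sizes were recorded for that input.
    """
    shapes = []
    for value in graph.inputs():
        value_type = value.type()
        if value_type.kind() == "TensorType":
            shapes.append(value_type.sizes())
    return shapes


shapes = traced_input_shapes(fn_t.graph)
print(shapes)
```

Note that the recorded sizes are those of the example input, including the batch dimension; the graph does not mark which dimensions are meant to be flexible.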
Thanks. If this is the best way to get the input shape in Python, I am guessing that the C++ torch::jit::script module won't offer much better options. In my context, I was hoping to save the Python torch.jit.ScriptModule to disk, load it in C++, and infer the input shape at C++ runtime.
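One possible workaround for this save-then-load scenario, which is not something proposed in the thread, is to record the example shape yourself next to the module instead of relying on the graph. The sketch below assumes the _extra_files argument of torch.jit.save and torch.jit.load; the Net module and the file name input_shape.json are hypothetical. On the C++ side, torch::jit::load has an overload that takes an extra-files map, which I believe could be used to read the same entry back, though that part is not shown here.

```python
import io
import json

import torch


class Net(torch.nn.Module):
    # Hypothetical module standing in for the real model.
    def forward(self, x):
        return x.sum(dim=1)


example = torch.rand(1, 2, 3, 4)
traced = torch.jit.trace(Net(), example)

# Hypothetical file name for the side-car metadata stored in the archive.
SHAPE_KEY = "input_shape.json"

# Save the traced module together with the example input shape.
buffer = io.BytesIO()
torch.jit.save(
    traced,
    buffer,
    _extra_files={SHAPE_KEY: json.dumps(list(example.shape))},
)

# Load it back; the requested extra file is filled into the dict in place.
buffer.seek(0)
extra = {SHAPE_KEY: ""}
loaded = torch.jit.load(buffer, _extra_files=extra)

raw = extra[SHAPE_KEY]
if isinstance(raw, bytes):
    # Depending on the PyTorch version the value may come back as bytes.
    raw = raw.decode("utf-8")
recovered_shape = json.loads(raw)
print(recovered_shape)
```

Storing the shape explicitly keeps the metadata under your control, so it does not depend on whatever type information the serialized graph happens to retain.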