The Python function torch.jit.trace() accepts an nn.Module and an example input, and returns a torch.jit.ScriptModule.
Is there a way to extract the expected input shape from the torch.jit.ScriptModule? In other words, if trace() was given a Tensor of shape (1, 2, 3, 4) as its second argument, can (_, 2, 3, 4) be recovered from the torch.jit.ScriptModule that trace() returns?
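A minimal sketch, assuming a recent PyTorch release: a traced module's graph records the type of each example input, and the Python binding TensorType.sizes() returns the concrete sizes when they were recorded (None otherwise). Tracing stores the example's exact sizes and cannot know which dimensions are meant to vary, so the batch dimension comes back as 1 rather than _. The Net module and the traced_input_shapes helper are hypothetical, and this shape information may not survive torch.jit.save/torch.jit.load, so check it on a reloaded module before relying on it.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical example module; any traceable nn.Module works here.
    def forward(self, x):
        return x.relu()


def traced_input_shapes(module):
    """Return the recorded sizes of each tensor input to forward().

    Hypothetical helper. Non-tensor graph inputs (such as the module
    itself, `self`, which is the first graph input) are skipped.
    """
    shapes = []
    for value in module.graph.inputs():
        t = value.type()
        if isinstance(t, torch._C.TensorType):
            # sizes() gives a list of ints, or None if sizes were not recorded.
            shapes.append(t.sizes())
    return shapes


traced = torch.jit.trace(Net(), torch.randn(1, 2, 3, 4))
print(traced_input_shapes(traced))
```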
Thanks. If this is the best way to get the input shape in Python, I am guessing that the C++ torch::jit::script::Module won't offer much better options. My goal is to save the Python torch.jit.ScriptModule to disk, load it in C++, and infer the input shape at C++ runtime.
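Because the recorded shapes are not guaranteed to be available after reloading, one workaround is to store the expected shape yourself as an extra file inside the saved archive. The sketch below uses the _extra_files arguments of torch.jit.save and torch.jit.load; the file name input_shape.json, the JSON layout, and the use of -1 for a variable batch dimension are hypothetical choices. On the C++ side, torch::jit::load has an overload that takes a torch::jit::ExtraFilesMap (a map from file names to contents) and fills it in the same way, but check the exact signature in your LibTorch version's headers.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical stand-in for the real model.
    def forward(self, x):
        return x.relu()


example = torch.randn(1, 2, 3, 4)
traced = torch.jit.trace(Net(), example)

# Record the expected shape, marking the (assumed) variable batch dimension as -1.
meta = {"input_shape": [-1] + list(example.shape[1:])}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.pt")
    # Store the metadata as an extra file inside the TorchScript archive.
    torch.jit.save(
        traced, path,
        _extra_files={"input_shape.json": json.dumps(meta).encode("utf-8")},
    )

    # Keys name the extra files to read; torch.jit.load fills in their contents.
    extra = {"input_shape.json": ""}
    loaded = torch.jit.load(path, _extra_files=extra)

# json.loads accepts either str or bytes for the loaded contents.
shape = json.loads(extra["input_shape.json"])["input_shape"]
print(shape)
```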