How to change the last layer(s) of a traced torchscript model

Hello, I am wondering if it is possible to change the last layer (or more) of a loaded TorchScript model?

This would be useful for retraining a TorchScript model to predict a different number of categories.

Right now I get an odd error when trying to overwrite the last layer manually.

>>> model = torch.jit.load("model_cpu.pth")
>>> model.last_layer

RecursiveScriptModule(original_name=Linear)

>>> new_last_layer = torch.jit.script(torch.nn.Linear(a, b))
>>> new_last_layer

RecursiveScriptModule(original_name=Linear)

>>> model.last_layer = new_last_layer
>>> model.forward(torch.rand([x, y, z ... ]))

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: isTensor() INTERNAL ASSERT FAILED at ../aten/src/ATen/core/ivalue_inl.h:86, please report a bug to PyTorch. Expected Tensor but got Bool
The above operation failed in interpreter.
Traceback (most recent call last):
/Users/****/venv/lib/python3.7/site-packages/torch/nn/functional.py(1370): linear
/Users/****/venv/lib/python3.7/site-packages/torch/nn/modules/linear.py(87): forward
/Users/****/venv/lib/python3.7/site-packages/torch/nn/modules/module.py(525): _slow_forward
/Users/****/venv/lib/python3.7/site-packages/torch/nn/modules/module.py(539): __call__
/Users/****/*****/ops/models.py(277): forward
/Users/****/venv/lib/python3.7/site-packages/torch/nn/modules/module.py(525): _slow_forward
/Users/****/venv/lib/python3.7/site-packages/torch/nn/modules/module.py(539): __call__
/Users/****/venv/lib/python3.7/site-packages/torch/jit/__init__.py(997): trace_module
/Users/****/venv/lib/python3.7/site-packages/torch/jit/__init__.py(858): trace
convert.py(91): main
convert.py(110): <module>
Serialized   File "code/__torch__/ops/models.py", line 1052
    input153 = torch.flatten(x31, 1, -1)
    input154 = torch.dropout(input153, 0.5, False)
    base_out = torch.addmm(bias52, input154, torch.t(weight105), beta=1, alpha=1)
                                             ~~~~~~~ <--- HERE
    _512 = ops.prim.NumToTensor(torch.size(base_out, 1))
    input_tensor = torch.view(base_out, [-1, 8, int(_512)])

@jhhurwitz I don’t think there’s an easy way of doing that currently. See this issue for more context: https://github.com/pytorch/pytorch/issues/21064

@eellison Is it on the roadmap to add an easy way to do this? Looks like the issue has stalled out a little bit.

In the meantime, even an example of the hard way to do this would be useful, if you happen to know of one.
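For reference, the closest workaround I've found so far is to avoid mutating the ScriptModule at all: rebuild the original Python module, load the traced model's weights into it, swap the head in eager mode, and then re-trace. A rough sketch below (the `Net` class and all the sizes are made-up stand-ins for the real architecture; the initial `torch.jit.trace` call stands in for `torch.jit.load` so the snippet is self-contained):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the original (untraced) model class.
class Net(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Linear(16, 32)
        self.last_layer = nn.Linear(32, num_classes)

    def forward(self, x):
        return self.last_layer(torch.relu(self.features(x)))

# Stand-in for `traced = torch.jit.load("model_cpu.pth")`.
traced = torch.jit.trace(Net(10), torch.rand(1, 16))

# 1. Rebuild the Python model and copy the traced model's weights into it.
model = Net(10)
model.load_state_dict(traced.state_dict())

# 2. Swap the head in eager mode, where attribute assignment is safe.
model.last_layer = nn.Linear(32, 5)  # new number of categories

# 3. Re-trace to get a fresh ScriptModule with the new head.
new_traced = torch.jit.trace(model, torch.rand(1, 16))
out = new_traced(torch.rand(4, 16))
print(out.shape)  # torch.Size([4, 5])
```

The obvious downside is that it requires the original model class, so it doesn't help if you only have the serialized `.pth` file.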

Thanks!