Can I perform jit.trace only on specific layers of a model?


I received a JIT-traced encoder file from a colleague, and I appended my own decoder layer (a simple MLP for classification) on top of it, wrapping everything as one module.

class Combined_Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.jit.load("myfriends_model.pth")
        self.decoder = torch.nn.Linear(in_features=..., out_features=...)

    def forward(self, x):
        inter_feat = self.encoder(x)
        output = self.decoder(inter_feat)
        return output

The code looks something like the above, but the following command crashed:

model_ts = torch.jit.trace(my_combined_net, inputs_ts, check_trace=False)

I assume this is because I was performing jit.trace on something that was already traced.
So how can I perform jit.trace only on the decoder part of my network, but still generate a single model_ts containing the JIT of both encoder and decoder?

Or, if I do jit.trace only on the decoder, how can I combine encoder_jit and decoder_jit into one module?


I'm unsure why it failed, since a warning should be raised if you try to trace an already-traced module:

model = nn.Linear(10, 10).cuda()
x = torch.randn(1, 10).cuda()
model = torch.jit.trace(model, x)
model = torch.jit.trace(model, x)
# UserWarning: The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is.
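Regarding the combination question: one approach that typically works is to trace the wrapper module itself; the already-scripted encoder submodule is inlined into the resulting graph, so you end up with a single ScriptModule. A minimal sketch, using a stand-in encoder and placeholder dimensions (10 in, 2 out) since the real shapes aren't given:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def forward(self, x):
        return x * 2

# stand-in for the colleague's already-scripted encoder;
# in the real setup this would be torch.jit.load(...)
encoder_ts = torch.jit.script(Encoder())

class CombinedNet(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder           # already a ScriptModule
        self.decoder = nn.Linear(10, 2)  # placeholder dimensions

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = CombinedNet(encoder_ts)
x = torch.randn(1, 10)
# tracing the eager wrapper inlines the scripted encoder,
# producing one ScriptModule covering encoder + decoder
model_ts = torch.jit.trace(model, x)
```

The resulting model_ts can be saved with torch.jit.save as a single artifact.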

With that being said, note that TorchScript is in maintenance mode and the recommended approach is to use torch.compile instead.
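A minimal torch.compile sketch for comparison (backend="eager" is used here only to keep the example free of compiler dependencies; in practice you would usually keep the default "inductor" backend):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)
# torch.compile wraps the module in an optimized callable;
# the original module is untouched and remains trainable
compiled = torch.compile(model, backend="eager")

x = torch.randn(1, 10)
out = compiled(x)  # same result as model(x)
```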