Hi,
I received a JIT-compiled (TorchScript) encoder file from my colleague. I am appending my own decoder layer on top of it (i.e. a simple MLP for classification) and have wrapped everything into one module.
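import torch
import torch.nn as nn

class Combined_Net(nn.Module):
    def __init__(self, in_features, out_features):
        super(Combined_Net, self).__init__()
        # TorchScript encoder received from my colleague
        self.encoder = torch.jit.load("myfriends_model.pth")
        # classification head; in_features must match the encoder's output size
        self.decoder = torch.nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        inter_feat = self.encoder(x)
        output = self.decoder(inter_feat)
        return output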
My code looks roughly like the above, but the following command crashed:
model_ts = torch.jit.trace(my_combined_net, inputs_ts, check_trace=False)
I suspect this is because I am calling jit.trace on a module that already contains a traced submodule.
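A minimal sketch for testing that suspicion in isolation. It assumes a small hypothetical `StandInEncoder`, traced once to play the role of `myfriends_model.pth`, and a hypothetical variant `CombinedNetSketch` that receives the encoder as a constructor argument instead of loading it from disk, so it runs without the real file. As far as I know, `torch.jit.trace` keeps an existing ScriptModule submodule as-is and only records the call to it, so if this runs on your PyTorch version, it may be worth checking other causes, such as `inputs_ts` not matching what the loaded encoder expects.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the colleague's encoder; the real one would come
# from torch.jit.load("myfriends_model.pth").
class StandInEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 8)

    def forward(self, x):
        return torch.relu(self.proj(x))

# Hypothetical variant of Combined_Net that takes the compiled encoder as an argument.
class CombinedNetSketch(nn.Module):
    def __init__(self, encoder, in_features, out_features):
        super().__init__()
        self.encoder = encoder  # already a ScriptModule
        self.decoder = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.decoder(self.encoder(x))

example = torch.randn(4, 16)
# Pre-trace the encoder, mimicking the already-compiled file from the colleague.
encoder_ts = torch.jit.trace(StandInEncoder(), example)
net = CombinedNetSketch(encoder_ts, in_features=8, out_features=3).eval()

# Same call as in the post: trace the whole module, nested ScriptModule included.
model_ts = torch.jit.trace(net, example, check_trace=False)

with torch.no_grad():
    same = torch.allclose(model_ts(example), net(example))
print(type(model_ts).__name__, same)
```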
So how can I run jit.trace only on the decoder part of my network, but still produce a single model_ts containing both the encoder and the decoder?
Alternatively, if I run jit.trace only on the decoder, how can I combine encoder_jit and decoder_jit into one module?
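One way to do this, sketched under assumptions: trace the decoder on its own, wrap both compiled parts in a thin `nn.Module`, and compile the wrapper with `torch.jit.script`, which keeps ScriptModule children as submodules. The result is a single module that `torch.jit.save` can write to one file. The encoder here is a hypothetical stand-in traced from a small `nn.Sequential`, the layer sizes (16, 8, 3) are made up for illustration, and `EncoderDecoder` is a hypothetical name.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-ins: encoder_jit plays the role of the colleague's file,
# decoder_jit is the classification head traced on its own.
example = torch.randn(4, 16)
encoder_jit = torch.jit.trace(nn.Sequential(nn.Linear(16, 8), nn.ReLU()), example)
decoder_jit = torch.jit.trace(nn.Linear(8, 3), torch.randn(4, 8))

class EncoderDecoder(nn.Module):
    """Thin wrapper whose forward just chains the two compiled parts."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))

# Scripting the wrapper reuses the already-compiled children as submodules.
combined_ts = torch.jit.script(EncoderDecoder(encoder_jit, decoder_jit))

# Round-trip through an in-memory buffer to show it saves as one file.
buffer = io.BytesIO()
torch.jit.save(combined_ts, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    matches = torch.allclose(reloaded(example), decoder_jit(encoder_jit(example)))
print(matches)
```

With the real file, `encoder_jit` would come from `torch.jit.load("myfriends_model.pth")`, and the decoder's example input would need the encoder's actual output shape.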
Thanks.