Fail to convert PyTorch module to TorchScript

I’ve created a model with a forward function like this:

class Net(nn.Module):
    ...
    def forward(self, num_nodes, num_feats, nodes):
        features = nn.Embedding(num_nodes, num_feats)
        features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

Then I save the model using:

traced_script_module = torch.jit.script(net)
traced_script_module.save(model_path1)

I trained the model successfully, but I get this error when saving it:

NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:...
'Embedding' is being compiled since it was called from '__torch__.___torch_mangle_0.Net.forward'
at <ipython-input-5-501dbaacc7a5>:42:8
    def forward(self, num_nodes, num_feats, nodes):
        
        features = nn.Embedding(num_nodes, num_feats)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

My PyTorch version is 1.3.
What is the best way to handle this?
Any help is appreciated!


You should move the initialization of submodules into __init__ instead of doing it in your model’s forward. This will also likely improve performance, since costly parameter allocations then happen only once instead of on every call to forward.

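import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self, num_nodes, num_feats, feat_data):
        super().__init__()  # must run before any submodule is assigned
        # Create the embedding once here, not on every forward call
        self.features = nn.Embedding(num_nodes, num_feats)
        # Replace its weights with the precomputed, frozen features
        self.features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

    def forward(self, nodes):
        result = self.features(nodes)
        # ... rest of the forward pass
        return result

model = M(num_nodes, num_feats, feat_data)
script_model = torch.jit.script(model)
script_model.save("script_model.pt")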
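To check that the saved file actually works, one option is to load it back with torch.jit.load and compare against the eager model. The sketch below is self-contained, so it makes assumptions: feat_data is a small made-up 5x3 matrix (a stand-in for the poster's real features), the embedding size is derived from its shape, and the file goes into a temporary directory rather than "script_model.pt".

```python
import os
import tempfile

import torch
import torch.nn as nn


class M(nn.Module):
    def __init__(self, feat_data: torch.Tensor):
        super().__init__()
        num_nodes, num_feats = feat_data.shape
        # Submodule created once in __init__, so the compiler only sees its forward
        self.features = nn.Embedding(num_nodes, num_feats)
        # Frozen, precomputed node features
        self.features.weight = nn.Parameter(feat_data.clone(), requires_grad=False)

    def forward(self, nodes: torch.Tensor) -> torch.Tensor:
        return self.features(nodes)


# Hypothetical feature matrix: 5 nodes with 3 features each
feat_data = torch.arange(15, dtype=torch.float32).reshape(5, 3)
model = M(feat_data)
script_model = torch.jit.script(model)

nodes = torch.tensor([0, 2, 4])
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "script_model.pt")
    script_model.save(path)
    loaded = torch.jit.load(path)
    out = loaded(nodes)  # rows 0, 2 and 4 of feat_data

print(out)
```

Because the lookup only selects rows of the frozen weight, the loaded module should return exactly the same tensor as the eager model for the same node indices.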

When a class is instantiated inside code that TorchScript compiles, the entire class must be compatible with the TorchScript compiler, which is not the case for most nn.Modules. However, if the nn.Modules are created in __init__ and stored as attributes on self, only the methods that the forward of your model M actually calls need to be compatible with the compiler (this should work for any module in nn except for these 3).
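A rough illustration of that scoping rule: torch.jit.script compiles a module's forward plus whatever it calls, so a method that forward never reaches is left alone even if it uses syntax the compiler rejects. The class Wrapper and its describe method below are hypothetical and exist only to show this; describe uses *args/**kwargs, the same kind of signature the original error message complains about.

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self, num_nodes: int, num_feats: int):
        super().__init__()
        # Stored on self in __init__: only Embedding.forward gets compiled
        self.features = nn.Embedding(num_nodes, num_feats)

    def describe(self, *args, **kwargs):
        # Variadic arguments are not supported by the TorchScript compiler,
        # but this method is never called from forward, so it is not compiled.
        return f"{len(args)} args, {len(kwargs)} kwargs"

    def forward(self, nodes: torch.Tensor) -> torch.Tensor:
        return self.features(nodes)


eager = Wrapper(10, 4)
scripted = torch.jit.script(eager)  # succeeds despite describe()
out = scripted(torch.tensor([0, 3, 7]))
print(out.shape)  # torch.Size([3, 4])
```

The uncompiled method still works on the eager module (eager.describe(1, 2, a=3) returns "2 args, 1 kwargs"); if such a method must exist on the scripted module too, it has to be made compatible with the compiler.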


Got it!
Thank you for the answers.